// client_example/client_example.rs

use std::collections::HashMap;

use ndarray::ArrayD;
use philote_mdo::{client::ExplicitClient, types::StreamOptions, ArrayMap, PhiloteError, Result};

6#[tokio::main]
7async fn main() -> Result<()> {
8    println!("🚀 Philote Rust Client Example");
9
10    // This example demonstrates how to use the ExplicitClient
11    // Note: This would connect to a running Philote server
12
13    let server_address = "http://localhost:50051";
14    println!("📡 Connecting to server at: {}", server_address);
15
16    // Try to connect to the server
17    let mut client = match ExplicitClient::connect(server_address).await {
18        Ok(client) => {
19            println!("✅ Connected successfully!");
20            client
21        }
22        Err(e) => {
23            println!("❌ Failed to connect to server: {}", e);
24            println!("💡 To run this example, you need a Philote server running.");
25            println!("   You can start one using the paraboloid example as a server.");
26            return Err(e);
27        }
28    };
29
30    // Configure streaming options
31    let stream_options = StreamOptions {
32        max_double_per_slice: 1000,
33    };
34    client = client.with_stream_options(stream_options);
35
36    println!("⚙️ Getting discipline information...");
37
38    // Get discipline properties
39    match client.get_info().await {
40        Ok(properties) => {
41            println!("📋 Discipline Properties:");
42            println!("   - Name: {}", properties.name);
43            println!("   - Version: {}", properties.version);
44            println!("   - Continuous: {}", properties.continuous);
45            println!("   - Differentiable: {}", properties.differentiable);
46            println!("   - Provides Gradients: {}", properties.provides_gradients);
47        }
48        Err(e) => {
49            println!("❌ Failed to get discipline info: {}", e);
50            return Err(e);
51        }
52    }
53
54    // Get available options
55    match client.get_available_options().await {
56        Ok(options) => {
57            println!("🔧 Available Options:");
58            for (name, type_str) in options {
59                println!("   - {}: {}", name, type_str);
60            }
61        }
62        Err(e) => {
63            println!("❌ Failed to get available options: {}", e);
64        }
65    }
66
67    // Setup the discipline
68    println!("🔄 Setting up discipline...");
69    if let Err(e) = client.setup().await {
70        println!("❌ Failed to setup discipline: {}", e);
71        return Err(e);
72    }
73
74    // Get variable definitions
75    match client.get_variable_definitions().await {
76        Ok(variables) => {
77            println!("📊 Variable Definitions:");
78            for var in variables {
79                println!("   - {}: {:?} ({})", var.name, var.shape, var.units);
80            }
81        }
82        Err(e) => {
83            println!("❌ Failed to get variable definitions: {}", e);
84        }
85    }
86
87    // Get partial definitions (if available)
88    match client.get_partial_definitions().await {
89        Ok(partials) => {
90            println!("📈 Partial Definitions:");
91            for partial in partials {
92                println!(
93                    "   - {}/{}: {:?}",
94                    partial.name, partial.subname, partial.shape
95                );
96            }
97        }
98        Err(e) => {
99            println!("❌ Failed to get partial definitions: {}", e);
100        }
101    }
102
103    // Create input arrays for computation
104    let mut inputs = HashMap::new();
105    inputs.insert("x".to_string(), ArrayD::from_elem(vec![1], 2.0));
106    inputs.insert("y".to_string(), ArrayD::from_elem(vec![1], -1.0));
107
108    println!("🧮 Computing function with inputs: x=2.0, y=-1.0");
109
110    // Call compute function
111    match client.compute_function(&inputs).await {
112        Ok(outputs) => {
113            println!("✅ Computation successful!");
114            for (name, array) in outputs {
115                println!("   Output {}: {}", name, array[[0]]);
116            }
117
118            // Also compute gradients
119            println!("📊 Computing gradients...");
120            match client.compute_gradient(&inputs).await {
121                Ok(partials) => {
122                    println!("✅ Gradient computation successful!");
123                    for ((func, var), array) in partials {
124                        println!("   ∂{}/∂{}: {}", func, var, array[[0]]);
125                    }
126                }
127                Err(e) => {
128                    println!("❌ Failed to compute gradients: {}", e);
129                }
130            }
131        }
132        Err(e) => {
133            println!("❌ Failed to compute function: {}", e);
134            return Err(e);
135        }
136    }
137
138    println!("🎉 Client example completed successfully!");
139
140    Ok(())
141}
142
143// Example of how to start a server for testing
144pub async fn start_test_server() -> Result<()> {
145    use philote_mdo::{
146        philote_info::{VariableMetaData, VariableType},
147        server::ExplicitServer,
148        traits::{Discipline, ExplicitDiscipline},
149    };
150    use std::net::SocketAddr;
151    use tonic::transport::Server;
152
153    // This is a simplified version of our paraboloid discipline
154    struct TestParaboloid {
155        variables: Vec<VariableMetaData>,
156        partials: Vec<(String, String)>,
157        options: HashMap<String, String>,
158    }
159
160    impl TestParaboloid {
161        fn new() -> Self {
162            Self {
163                variables: Vec::new(),
164                partials: Vec::new(),
165                options: HashMap::new(),
166            }
167        }
168    }
169
170    impl Discipline for TestParaboloid {
171        fn name(&self) -> &str {
172            "TestParaboloid"
173        }
174        fn version(&self) -> &str {
175            "1.0.0"
176        }
177        fn is_continuous(&self) -> bool {
178            true
179        }
180        fn is_differentiable(&self) -> bool {
181            true
182        }
183        fn provides_gradients(&self) -> bool {
184            true
185        }
186
187        fn add_input(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
188            let var_meta = VariableMetaData {
189                r#type: VariableType::KInput as i32,
190                name: name.to_string(),
191                shape: shape.iter().map(|&s| s as i64).collect(),
192                units: units.to_string(),
193                dynamic_shape: false,
194            };
195            self.variables.push(var_meta);
196            Ok(())
197        }
198
199        fn add_output(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
200            let var_meta = VariableMetaData {
201                r#type: VariableType::KOutput as i32,
202                name: name.to_string(),
203                shape: shape.iter().map(|&s| s as i64).collect(),
204                units: units.to_string(),
205                dynamic_shape: false,
206            };
207            self.variables.push(var_meta);
208            Ok(())
209        }
210
211        fn add_option(&mut self, name: &str, option_type: &str) -> Result<()> {
212            self.options
213                .insert(name.to_string(), option_type.to_string());
214            Ok(())
215        }
216
217        fn set_options(&mut self, _options: &HashMap<String, serde_json::Value>) -> Result<()> {
218            Ok(())
219        }
220
221        fn setup(&mut self) -> Result<()> {
222            self.variables.clear();
223            self.add_input("x", &[1], "")?;
224            self.add_input("y", &[1], "")?;
225            self.add_output("f", &[1], "")?;
226            Ok(())
227        }
228
229        fn declare_partials(&mut self, func: &str, var: &str) -> Result<()> {
230            self.partials.push((func.to_string(), var.to_string()));
231            Ok(())
232        }
233
234        fn setup_partials(&mut self) -> Result<()> {
235            self.declare_partials("f", "x")?;
236            self.declare_partials("f", "y")?;
237            Ok(())
238        }
239
240        fn get_variable_definitions(&self) -> Result<Vec<VariableMetaData>> {
241            Ok(self.variables.clone())
242        }
243
244        fn get_partials_definitions(&self) -> Result<Vec<(String, String)>> {
245            Ok(self.partials.clone())
246        }
247
248        fn get_available_options(&self) -> Result<HashMap<String, String>> {
249            Ok(self.options.clone())
250        }
251    }
252
253    #[async_trait::async_trait]
254    impl ExplicitDiscipline for TestParaboloid {
255        async fn compute(&self, inputs: &ArrayMap) -> Result<ArrayMap> {
256            let x = inputs
257                .get("x")
258                .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
259            let y = inputs
260                .get("y")
261                .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
262
263            let x_val = x[[0]];
264            let y_val = y[[0]];
265            let f_val = (x_val - 3.0).powi(2) + x_val * y_val + (y_val + 4.0).powi(2) - 3.0;
266
267            let mut outputs = HashMap::new();
268            outputs.insert("f".to_string(), ArrayD::from_elem(vec![1], f_val));
269            Ok(outputs)
270        }
271
272        async fn compute_partials(&self, inputs: &ArrayMap) -> Result<philote_mdo::PartialMap> {
273            let x = inputs
274                .get("x")
275                .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
276            let y = inputs
277                .get("y")
278                .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
279
280            let x_val = x[[0]];
281            let y_val = y[[0]];
282
283            let df_dx = 2.0 * (x_val - 3.0) + y_val;
284            let df_dy = x_val + 2.0 * (y_val + 4.0);
285
286            let mut partials = HashMap::new();
287            partials.insert(
288                ("f".to_string(), "x".to_string()),
289                ArrayD::from_elem(vec![1], df_dx),
290            );
291            partials.insert(
292                ("f".to_string(), "y".to_string()),
293                ArrayD::from_elem(vec![1], df_dy),
294            );
295            Ok(partials)
296        }
297    }
298
299    let mut discipline = TestParaboloid::new();
300    discipline.setup()?;
301    discipline.setup_partials()?;
302
303    let server_impl = ExplicitServer::new(discipline).with_verbose(true);
304
305    let addr: SocketAddr = "127.0.0.1:50051"
306        .parse()
307        .map_err(|e| PhiloteError::config_error(format!("Invalid address: {}", e)))?;
308
309    println!("🚀 Starting test server on {}", addr);
310
311    Server::builder()
312        .add_service(
313            philote_mdo::philote_info::explicit_service_server::ExplicitServiceServer::new(
314                server_impl,
315            ),
316        )
317        .serve(addr)
318        .await
319        .map_err(|e| PhiloteError::config_error(format!("Server error: {}", e)))?;
320
321    Ok(())
322}