1use ndarray::ArrayD;
2use std::collections::HashMap;
3
4use philote_mdo::{client::ExplicitClient, types::StreamOptions, ArrayMap, PhiloteError, Result};
5
6#[tokio::main]
7async fn main() -> Result<()> {
8 println!("🚀 Philote Rust Client Example");
9
10 let server_address = "http://localhost:50051";
14 println!("📡 Connecting to server at: {}", server_address);
15
16 let mut client = match ExplicitClient::connect(server_address).await {
18 Ok(client) => {
19 println!("✅ Connected successfully!");
20 client
21 }
22 Err(e) => {
23 println!("❌ Failed to connect to server: {}", e);
24 println!("💡 To run this example, you need a Philote server running.");
25 println!(" You can start one using the paraboloid example as a server.");
26 return Err(e);
27 }
28 };
29
30 let stream_options = StreamOptions {
32 max_double_per_slice: 1000,
33 };
34 client = client.with_stream_options(stream_options);
35
36 println!("⚙️ Getting discipline information...");
37
38 match client.get_info().await {
40 Ok(properties) => {
41 println!("📋 Discipline Properties:");
42 println!(" - Name: {}", properties.name);
43 println!(" - Version: {}", properties.version);
44 println!(" - Continuous: {}", properties.continuous);
45 println!(" - Differentiable: {}", properties.differentiable);
46 println!(" - Provides Gradients: {}", properties.provides_gradients);
47 }
48 Err(e) => {
49 println!("❌ Failed to get discipline info: {}", e);
50 return Err(e);
51 }
52 }
53
54 match client.get_available_options().await {
56 Ok(options) => {
57 println!("🔧 Available Options:");
58 for (name, type_str) in options {
59 println!(" - {}: {}", name, type_str);
60 }
61 }
62 Err(e) => {
63 println!("❌ Failed to get available options: {}", e);
64 }
65 }
66
67 println!("🔄 Setting up discipline...");
69 if let Err(e) = client.setup().await {
70 println!("❌ Failed to setup discipline: {}", e);
71 return Err(e);
72 }
73
74 match client.get_variable_definitions().await {
76 Ok(variables) => {
77 println!("📊 Variable Definitions:");
78 for var in variables {
79 println!(" - {}: {:?} ({})", var.name, var.shape, var.units);
80 }
81 }
82 Err(e) => {
83 println!("❌ Failed to get variable definitions: {}", e);
84 }
85 }
86
87 match client.get_partial_definitions().await {
89 Ok(partials) => {
90 println!("📈 Partial Definitions:");
91 for partial in partials {
92 println!(
93 " - {}/{}: {:?}",
94 partial.name, partial.subname, partial.shape
95 );
96 }
97 }
98 Err(e) => {
99 println!("❌ Failed to get partial definitions: {}", e);
100 }
101 }
102
103 let mut inputs = HashMap::new();
105 inputs.insert("x".to_string(), ArrayD::from_elem(vec![1], 2.0));
106 inputs.insert("y".to_string(), ArrayD::from_elem(vec![1], -1.0));
107
108 println!("🧮 Computing function with inputs: x=2.0, y=-1.0");
109
110 match client.compute_function(&inputs).await {
112 Ok(outputs) => {
113 println!("✅ Computation successful!");
114 for (name, array) in outputs {
115 println!(" Output {}: {}", name, array[[0]]);
116 }
117
118 println!("📊 Computing gradients...");
120 match client.compute_gradient(&inputs).await {
121 Ok(partials) => {
122 println!("✅ Gradient computation successful!");
123 for ((func, var), array) in partials {
124 println!(" ∂{}/∂{}: {}", func, var, array[[0]]);
125 }
126 }
127 Err(e) => {
128 println!("❌ Failed to compute gradients: {}", e);
129 }
130 }
131 }
132 Err(e) => {
133 println!("❌ Failed to compute function: {}", e);
134 return Err(e);
135 }
136 }
137
138 println!("🎉 Client example completed successfully!");
139
140 Ok(())
141}
142
143pub async fn start_test_server() -> Result<()> {
145 use philote_mdo::{
146 philote_info::{VariableMetaData, VariableType},
147 server::ExplicitServer,
148 traits::{Discipline, ExplicitDiscipline},
149 };
150 use std::net::SocketAddr;
151 use tonic::transport::Server;
152
153 struct TestParaboloid {
155 variables: Vec<VariableMetaData>,
156 partials: Vec<(String, String)>,
157 options: HashMap<String, String>,
158 }
159
160 impl TestParaboloid {
161 fn new() -> Self {
162 Self {
163 variables: Vec::new(),
164 partials: Vec::new(),
165 options: HashMap::new(),
166 }
167 }
168 }
169
170 impl Discipline for TestParaboloid {
171 fn name(&self) -> &str {
172 "TestParaboloid"
173 }
174 fn version(&self) -> &str {
175 "1.0.0"
176 }
177 fn is_continuous(&self) -> bool {
178 true
179 }
180 fn is_differentiable(&self) -> bool {
181 true
182 }
183 fn provides_gradients(&self) -> bool {
184 true
185 }
186
187 fn add_input(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
188 let var_meta = VariableMetaData {
189 r#type: VariableType::KInput as i32,
190 name: name.to_string(),
191 shape: shape.iter().map(|&s| s as i64).collect(),
192 units: units.to_string(),
193 dynamic_shape: false,
194 };
195 self.variables.push(var_meta);
196 Ok(())
197 }
198
199 fn add_output(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
200 let var_meta = VariableMetaData {
201 r#type: VariableType::KOutput as i32,
202 name: name.to_string(),
203 shape: shape.iter().map(|&s| s as i64).collect(),
204 units: units.to_string(),
205 dynamic_shape: false,
206 };
207 self.variables.push(var_meta);
208 Ok(())
209 }
210
211 fn add_option(&mut self, name: &str, option_type: &str) -> Result<()> {
212 self.options
213 .insert(name.to_string(), option_type.to_string());
214 Ok(())
215 }
216
217 fn set_options(&mut self, _options: &HashMap<String, serde_json::Value>) -> Result<()> {
218 Ok(())
219 }
220
221 fn setup(&mut self) -> Result<()> {
222 self.variables.clear();
223 self.add_input("x", &[1], "")?;
224 self.add_input("y", &[1], "")?;
225 self.add_output("f", &[1], "")?;
226 Ok(())
227 }
228
229 fn declare_partials(&mut self, func: &str, var: &str) -> Result<()> {
230 self.partials.push((func.to_string(), var.to_string()));
231 Ok(())
232 }
233
234 fn setup_partials(&mut self) -> Result<()> {
235 self.declare_partials("f", "x")?;
236 self.declare_partials("f", "y")?;
237 Ok(())
238 }
239
240 fn get_variable_definitions(&self) -> Result<Vec<VariableMetaData>> {
241 Ok(self.variables.clone())
242 }
243
244 fn get_partials_definitions(&self) -> Result<Vec<(String, String)>> {
245 Ok(self.partials.clone())
246 }
247
248 fn get_available_options(&self) -> Result<HashMap<String, String>> {
249 Ok(self.options.clone())
250 }
251 }
252
253 #[async_trait::async_trait]
254 impl ExplicitDiscipline for TestParaboloid {
255 async fn compute(&self, inputs: &ArrayMap) -> Result<ArrayMap> {
256 let x = inputs
257 .get("x")
258 .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
259 let y = inputs
260 .get("y")
261 .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
262
263 let x_val = x[[0]];
264 let y_val = y[[0]];
265 let f_val = (x_val - 3.0).powi(2) + x_val * y_val + (y_val + 4.0).powi(2) - 3.0;
266
267 let mut outputs = HashMap::new();
268 outputs.insert("f".to_string(), ArrayD::from_elem(vec![1], f_val));
269 Ok(outputs)
270 }
271
272 async fn compute_partials(&self, inputs: &ArrayMap) -> Result<philote_mdo::PartialMap> {
273 let x = inputs
274 .get("x")
275 .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
276 let y = inputs
277 .get("y")
278 .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
279
280 let x_val = x[[0]];
281 let y_val = y[[0]];
282
283 let df_dx = 2.0 * (x_val - 3.0) + y_val;
284 let df_dy = x_val + 2.0 * (y_val + 4.0);
285
286 let mut partials = HashMap::new();
287 partials.insert(
288 ("f".to_string(), "x".to_string()),
289 ArrayD::from_elem(vec![1], df_dx),
290 );
291 partials.insert(
292 ("f".to_string(), "y".to_string()),
293 ArrayD::from_elem(vec![1], df_dy),
294 );
295 Ok(partials)
296 }
297 }
298
299 let mut discipline = TestParaboloid::new();
300 discipline.setup()?;
301 discipline.setup_partials()?;
302
303 let server_impl = ExplicitServer::new(discipline).with_verbose(true);
304
305 let addr: SocketAddr = "127.0.0.1:50051"
306 .parse()
307 .map_err(|e| PhiloteError::config_error(format!("Invalid address: {}", e)))?;
308
309 println!("🚀 Starting test server on {}", addr);
310
311 Server::builder()
312 .add_service(
313 philote_mdo::philote_info::explicit_service_server::ExplicitServiceServer::new(
314 server_impl,
315 ),
316 )
317 .serve(addr)
318 .await
319 .map_err(|e| PhiloteError::config_error(format!("Server error: {}", e)))?;
320
321 Ok(())
322}