1use async_trait::async_trait;
2use ndarray::ArrayD;
3use std::collections::HashMap;
4
5use philote_mdo::{
6 philote_info::{VariableMetaData, VariableType},
7 server::ExplicitServer,
8 traits::{Discipline, ExplicitDiscipline},
9 ArrayMap, PartialMap, PhiloteError, Result,
10};
11
12pub struct Paraboloid {
13 variables: Vec<VariableMetaData>,
14 partials: Vec<(String, String)>,
15 options: HashMap<String, String>,
16}
17
18impl Paraboloid {
19 pub fn new() -> Self {
20 Self {
21 variables: Vec::new(),
22 partials: Vec::new(),
23 options: HashMap::new(),
24 }
25 }
26}
27
28impl Discipline for Paraboloid {
29 fn name(&self) -> &str {
30 "Paraboloid"
31 }
32
33 fn version(&self) -> &str {
34 "1.0.0"
35 }
36
37 fn is_continuous(&self) -> bool {
38 true
39 }
40
41 fn is_differentiable(&self) -> bool {
42 true
43 }
44
45 fn provides_gradients(&self) -> bool {
46 true
47 }
48
49 fn initialize(&mut self) -> Result<()> {
50 self.add_option("a", "double")?;
51 self.add_option("b", "double")?;
52 Ok(())
53 }
54
55 fn add_input(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
56 let var_meta = VariableMetaData {
57 r#type: VariableType::KInput as i32,
58 name: name.to_string(),
59 shape: shape.iter().map(|&s| s as i64).collect(),
60 units: units.to_string(),
61 dynamic_shape: false,
62 };
63 self.variables.push(var_meta);
64 Ok(())
65 }
66
67 fn add_output(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
68 let var_meta = VariableMetaData {
69 r#type: VariableType::KOutput as i32,
70 name: name.to_string(),
71 shape: shape.iter().map(|&s| s as i64).collect(),
72 units: units.to_string(),
73 dynamic_shape: false,
74 };
75 self.variables.push(var_meta);
76 Ok(())
77 }
78
79 fn add_option(&mut self, name: &str, option_type: &str) -> Result<()> {
80 self.options
81 .insert(name.to_string(), option_type.to_string());
82 Ok(())
83 }
84
85 fn set_options(&mut self, _options: &HashMap<String, serde_json::Value>) -> Result<()> {
86 Ok(())
88 }
89
90 fn setup(&mut self) -> Result<()> {
91 self.variables.clear();
93
94 self.add_input("x", &[1], "")?;
96 self.add_input("y", &[1], "")?;
97
98 self.add_output("f", &[1], "")?;
100
101 Ok(())
102 }
103
104 fn declare_partials(&mut self, func: &str, var: &str) -> Result<()> {
105 self.partials.push((func.to_string(), var.to_string()));
106 Ok(())
107 }
108
109 fn setup_partials(&mut self) -> Result<()> {
110 self.declare_partials("f", "x")?;
111 self.declare_partials("f", "y")?;
112 Ok(())
113 }
114
115 fn get_variable_definitions(&self) -> Result<Vec<VariableMetaData>> {
116 Ok(self.variables.clone())
117 }
118
119 fn get_partials_definitions(&self) -> Result<Vec<(String, String)>> {
120 Ok(self.partials.clone())
121 }
122
123 fn get_available_options(&self) -> Result<HashMap<String, String>> {
124 Ok(self.options.clone())
125 }
126}
127
128#[async_trait]
129impl ExplicitDiscipline for Paraboloid {
130 async fn compute(&self, inputs: &ArrayMap) -> Result<ArrayMap> {
131 let x = inputs
132 .get("x")
133 .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
134
135 let y = inputs
136 .get("y")
137 .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
138
139 if x.len() != 1 || y.len() != 1 {
140 return Err(PhiloteError::array_error("Expected scalar inputs"));
141 }
142
143 let x_val = x[[0]];
144 let y_val = y[[0]];
145
146 let f_val = (x_val - 3.0).powi(2) + x_val * y_val + (y_val + 4.0).powi(2) - 3.0;
148
149 let mut outputs = HashMap::new();
150 let f_array = ArrayD::from_elem(vec![1], f_val);
151 outputs.insert("f".to_string(), f_array);
152
153 Ok(outputs)
154 }
155
156 async fn compute_partials(&self, inputs: &ArrayMap) -> Result<PartialMap> {
157 let x = inputs
158 .get("x")
159 .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
160
161 let y = inputs
162 .get("y")
163 .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
164
165 if x.len() != 1 || y.len() != 1 {
166 return Err(PhiloteError::array_error("Expected scalar inputs"));
167 }
168
169 let x_val = x[[0]];
170 let y_val = y[[0]];
171
172 let df_dx = 2.0 * (x_val - 3.0) + y_val;
174
175 let df_dy = x_val + 2.0 * (y_val + 4.0);
177
178 let mut partials = HashMap::new();
179 partials.insert(
180 ("f".to_string(), "x".to_string()),
181 ArrayD::from_elem(vec![1], df_dx),
182 );
183 partials.insert(
184 ("f".to_string(), "y".to_string()),
185 ArrayD::from_elem(vec![1], df_dy),
186 );
187
188 Ok(partials)
189 }
190}
191
192#[tokio::main]
193async fn main() -> Result<()> {
194 let mut paraboloid = Paraboloid::new();
196 paraboloid.initialize()?;
197 paraboloid.setup()?;
198 paraboloid.setup_partials()?;
199
200 let server = ExplicitServer::new(paraboloid).with_verbose(true);
202
203 println!("✅ Paraboloid discipline server created successfully!");
204 println!("🚀 The Rust port of Philote-MDO is working!");
205
206 let mut inputs = HashMap::new();
208 inputs.insert("x".to_string(), ArrayD::from_elem(vec![1], 2.0));
209 inputs.insert("y".to_string(), ArrayD::from_elem(vec![1], -1.0));
210
211 let discipline = server.discipline().read().await;
212 let outputs = discipline.compute(&inputs).await?;
213 let partials = discipline.compute_partials(&inputs).await?;
214
215 println!("\nTest computation:");
216 println!("Inputs: x = {}, y = {}", inputs["x"][[0]], inputs["y"][[0]]);
217 println!("Output: f = {}", outputs["f"][[0]]);
218 println!(
219 "Partials: df/dx = {}, df/dy = {}",
220 partials[&("f".to_string(), "x".to_string())][[0]],
221 partials[&("f".to_string(), "y".to_string())][[0]]
222 );
223
224 Ok(())
225}