Skip to main content

paraboloid/
paraboloid.rs

use async_trait::async_trait;
use ndarray::ArrayD;
use std::collections::HashMap;
use std::convert::TryFrom;

use philote_mdo::{
    philote_info::{VariableMetaData, VariableType},
    server::ExplicitServer,
    traits::{Discipline, ExplicitDiscipline},
    ArrayMap, PartialMap, PhiloteError, Result,
};
11
12pub struct Paraboloid {
13    variables: Vec<VariableMetaData>,
14    partials: Vec<(String, String)>,
15    options: HashMap<String, String>,
16}
17
18impl Paraboloid {
19    pub fn new() -> Self {
20        Self {
21            variables: Vec::new(),
22            partials: Vec::new(),
23            options: HashMap::new(),
24        }
25    }
26}
27
28impl Discipline for Paraboloid {
29    fn name(&self) -> &str {
30        "Paraboloid"
31    }
32
33    fn version(&self) -> &str {
34        "1.0.0"
35    }
36
37    fn is_continuous(&self) -> bool {
38        true
39    }
40
41    fn is_differentiable(&self) -> bool {
42        true
43    }
44
45    fn provides_gradients(&self) -> bool {
46        true
47    }
48
49    fn initialize(&mut self) -> Result<()> {
50        self.add_option("a", "double")?;
51        self.add_option("b", "double")?;
52        Ok(())
53    }
54
55    fn add_input(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
56        let var_meta = VariableMetaData {
57            r#type: VariableType::KInput as i32,
58            name: name.to_string(),
59            shape: shape.iter().map(|&s| s as i64).collect(),
60            units: units.to_string(),
61            dynamic_shape: false,
62        };
63        self.variables.push(var_meta);
64        Ok(())
65    }
66
67    fn add_output(&mut self, name: &str, shape: &[usize], units: &str) -> Result<()> {
68        let var_meta = VariableMetaData {
69            r#type: VariableType::KOutput as i32,
70            name: name.to_string(),
71            shape: shape.iter().map(|&s| s as i64).collect(),
72            units: units.to_string(),
73            dynamic_shape: false,
74        };
75        self.variables.push(var_meta);
76        Ok(())
77    }
78
79    fn add_option(&mut self, name: &str, option_type: &str) -> Result<()> {
80        self.options
81            .insert(name.to_string(), option_type.to_string());
82        Ok(())
83    }
84
85    fn set_options(&mut self, _options: &HashMap<String, serde_json::Value>) -> Result<()> {
86        // For this example, we'll use default values
87        Ok(())
88    }
89
90    fn setup(&mut self) -> Result<()> {
91        // Clear existing variables
92        self.variables.clear();
93
94        // Add inputs
95        self.add_input("x", &[1], "")?;
96        self.add_input("y", &[1], "")?;
97
98        // Add output
99        self.add_output("f", &[1], "")?;
100
101        Ok(())
102    }
103
104    fn declare_partials(&mut self, func: &str, var: &str) -> Result<()> {
105        self.partials.push((func.to_string(), var.to_string()));
106        Ok(())
107    }
108
109    fn setup_partials(&mut self) -> Result<()> {
110        self.declare_partials("f", "x")?;
111        self.declare_partials("f", "y")?;
112        Ok(())
113    }
114
115    fn get_variable_definitions(&self) -> Result<Vec<VariableMetaData>> {
116        Ok(self.variables.clone())
117    }
118
119    fn get_partials_definitions(&self) -> Result<Vec<(String, String)>> {
120        Ok(self.partials.clone())
121    }
122
123    fn get_available_options(&self) -> Result<HashMap<String, String>> {
124        Ok(self.options.clone())
125    }
126}
127
128#[async_trait]
129impl ExplicitDiscipline for Paraboloid {
130    async fn compute(&self, inputs: &ArrayMap) -> Result<ArrayMap> {
131        let x = inputs
132            .get("x")
133            .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
134
135        let y = inputs
136            .get("y")
137            .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
138
139        if x.len() != 1 || y.len() != 1 {
140            return Err(PhiloteError::array_error("Expected scalar inputs"));
141        }
142
143        let x_val = x[[0]];
144        let y_val = y[[0]];
145
146        // f = (x - 3)^2 + x*y + (y + 4)^2 - 3
147        let f_val = (x_val - 3.0).powi(2) + x_val * y_val + (y_val + 4.0).powi(2) - 3.0;
148
149        let mut outputs = HashMap::new();
150        let f_array = ArrayD::from_elem(vec![1], f_val);
151        outputs.insert("f".to_string(), f_array);
152
153        Ok(outputs)
154    }
155
156    async fn compute_partials(&self, inputs: &ArrayMap) -> Result<PartialMap> {
157        let x = inputs
158            .get("x")
159            .ok_or_else(|| PhiloteError::VariableNotFound("x".to_string()))?;
160
161        let y = inputs
162            .get("y")
163            .ok_or_else(|| PhiloteError::VariableNotFound("y".to_string()))?;
164
165        if x.len() != 1 || y.len() != 1 {
166            return Err(PhiloteError::array_error("Expected scalar inputs"));
167        }
168
169        let x_val = x[[0]];
170        let y_val = y[[0]];
171
172        // df/dx = 2*(x - 3) + y
173        let df_dx = 2.0 * (x_val - 3.0) + y_val;
174
175        // df/dy = x + 2*(y + 4)
176        let df_dy = x_val + 2.0 * (y_val + 4.0);
177
178        let mut partials = HashMap::new();
179        partials.insert(
180            ("f".to_string(), "x".to_string()),
181            ArrayD::from_elem(vec![1], df_dx),
182        );
183        partials.insert(
184            ("f".to_string(), "y".to_string()),
185            ArrayD::from_elem(vec![1], df_dy),
186        );
187
188        Ok(partials)
189    }
190}
191
192#[tokio::main]
193async fn main() -> Result<()> {
194    // Create the discipline
195    let mut paraboloid = Paraboloid::new();
196    paraboloid.initialize()?;
197    paraboloid.setup()?;
198    paraboloid.setup_partials()?;
199
200    // Create the server
201    let server = ExplicitServer::new(paraboloid).with_verbose(true);
202
203    println!("✅ Paraboloid discipline server created successfully!");
204    println!("🚀 The Rust port of Philote-MDO is working!");
205
206    // Test the discipline directly
207    let mut inputs = HashMap::new();
208    inputs.insert("x".to_string(), ArrayD::from_elem(vec![1], 2.0));
209    inputs.insert("y".to_string(), ArrayD::from_elem(vec![1], -1.0));
210
211    let discipline = server.discipline().read().await;
212    let outputs = discipline.compute(&inputs).await?;
213    let partials = discipline.compute_partials(&inputs).await?;
214
215    println!("\nTest computation:");
216    println!("Inputs: x = {}, y = {}", inputs["x"][[0]], inputs["y"][[0]]);
217    println!("Output: f = {}", outputs["f"][[0]]);
218    println!(
219        "Partials: df/dx = {}, df/dy = {}",
220        partials[&("f".to_string(), "x".to_string())][[0]],
221        partials[&("f".to_string(), "y".to_string())][[0]]
222    );
223
224    Ok(())
225}