Skip to main content

philote_mdo/
types.rs

//! Data types and Protocol Buffer conversions
//!
//! This module provides Rust data structures for working with discipline variables
//! and arrays, along with conversions to and from Protocol Buffer messages.
//!
//! # Key Types
//!
//! - [`VariableData`] - Complete variable with name, data array, units, and type
//! - [`ArrayData`] - Chunked array data for streaming transmission
//! - [`StreamOptions`] - Configuration for array streaming behavior
//! - [`PartialsInfo`] - Name, subname, and shape metadata for a block of partials
//! - [`ArrayChunker`] - Utility for splitting large arrays into chunks
//!
//! # Protocol Buffer Conversion
//!
//! Types implement the `From` and `TryFrom` traits for conversion between
//! Rust structures and Protocol Buffer messages, as used for gRPC communication.
17
18use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
19
20use crate::philote_info::{Array, PartialsMetaData, VariableType};
21use crate::{PhiloteError, Result};
22
23#[derive(Debug, Clone)]
24pub struct VariableData {
25    pub name: String,
26    pub data: ArrayD<f64>,
27    pub units: String,
28    pub var_type: VariableType,
29}
30
31impl VariableData {
32    pub fn new(name: String, data: ArrayD<f64>, units: String, var_type: VariableType) -> Self {
33        Self {
34            name,
35            data,
36            units,
37            var_type,
38        }
39    }
40
41    pub fn zeros(name: String, shape: &[usize], units: String, var_type: VariableType) -> Self {
42        let data = ArrayD::zeros(shape);
43        Self::new(name, data, units, var_type)
44    }
45
46    pub fn shape(&self) -> &[usize] {
47        self.data.shape()
48    }
49
50    pub fn size(&self) -> usize {
51        self.data.len()
52    }
53
54    pub fn view(&self) -> ArrayViewD<'_, f64> {
55        self.data.view()
56    }
57
58    pub fn view_mut(&mut self) -> ArrayViewMutD<'_, f64> {
59        self.data.view_mut()
60    }
61
62    pub fn flatten(&self) -> Vec<f64> {
63        self.data.iter().copied().collect()
64    }
65
66    pub fn from_flat(
67        name: String,
68        flat_data: &[f64],
69        shape: &[usize],
70        units: String,
71        var_type: VariableType,
72    ) -> Result<Self> {
73        let expected_size: usize = shape.iter().product();
74        if flat_data.len() != expected_size {
75            return Err(PhiloteError::ShapeMismatch {
76                expected: vec![expected_size],
77                actual: vec![flat_data.len()],
78            });
79        }
80
81        let data = ArrayD::from_shape_vec(shape, flat_data.to_vec()).map_err(|e| {
82            PhiloteError::array_error(format!("Failed to create array from flat data: {}", e))
83        })?;
84
85        Ok(Self::new(name, data, units, var_type))
86    }
87}
88
89#[derive(Debug, Clone)]
90pub struct ArrayData {
91    pub name: String,
92    pub subname: Option<String>,
93    pub start: usize,
94    pub end: usize,
95    pub var_type: VariableType,
96    pub data: Vec<f64>,
97}
98
99impl ArrayData {
100    pub fn new(
101        name: String,
102        subname: Option<String>,
103        start: usize,
104        end: usize,
105        var_type: VariableType,
106        data: Vec<f64>,
107    ) -> Self {
108        Self {
109            name,
110            subname,
111            start,
112            end,
113            var_type,
114            data,
115        }
116    }
117
118    pub fn size(&self) -> usize {
119        self.data.len()
120    }
121}
122
123impl From<ArrayData> for Array {
124    fn from(array_data: ArrayData) -> Self {
125        Array {
126            name: array_data.name,
127            subname: array_data.subname.unwrap_or_default(),
128            start: array_data.start as i64,
129            end: array_data.end as i64,
130            r#type: array_data.var_type.into(),
131            data: array_data.data,
132        }
133    }
134}
135
136impl TryFrom<Array> for ArrayData {
137    type Error = PhiloteError;
138
139    fn try_from(array: Array) -> Result<Self> {
140        let var_type = VariableType::try_from(array.r#type).map_err(|_| {
141            PhiloteError::InvalidVariableType(format!("Invalid type: {}", array.r#type))
142        })?;
143
144        if array.data.is_empty() {
145            return Err(PhiloteError::array_error("Array contains no data"));
146        }
147
148        let subname = if array.subname.is_empty() {
149            None
150        } else {
151            Some(array.subname)
152        };
153
154        Ok(ArrayData::new(
155            array.name,
156            subname,
157            array.start as usize,
158            array.end as usize,
159            var_type,
160            array.data,
161        ))
162    }
163}
164
165#[derive(Debug, Clone, Copy)]
166pub struct StreamOptions {
167    pub max_double_per_slice: usize,
168}
169
170impl Default for StreamOptions {
171    fn default() -> Self {
172        Self {
173            max_double_per_slice: 1000,
174        }
175    }
176}
177
178impl From<crate::philote_info::StreamOptions> for StreamOptions {
179    fn from(opts: crate::philote_info::StreamOptions) -> Self {
180        Self {
181            max_double_per_slice: opts.num_double as usize,
182        }
183    }
184}
185
186impl From<StreamOptions> for crate::philote_info::StreamOptions {
187    fn from(opts: StreamOptions) -> Self {
188        crate::philote_info::StreamOptions {
189            num_double: opts.max_double_per_slice as i64,
190        }
191    }
192}
193
194#[derive(Debug, Clone)]
195pub struct PartialsInfo {
196    pub name: String,
197    pub subname: String,
198    pub shape: Vec<usize>,
199}
200
201impl PartialsInfo {
202    pub fn new(name: String, subname: String, shape: Vec<usize>) -> Self {
203        Self {
204            name,
205            subname,
206            shape,
207        }
208    }
209
210    pub fn size(&self) -> usize {
211        self.shape.iter().product()
212    }
213}
214
215impl From<PartialsInfo> for PartialsMetaData {
216    fn from(info: PartialsInfo) -> Self {
217        PartialsMetaData {
218            name: info.name,
219            subname: info.subname,
220            shape: info.shape.into_iter().map(|s| s as i64).collect(),
221        }
222    }
223}
224
225pub struct ArrayChunker {
226    chunk_size: usize,
227}
228
229impl ArrayChunker {
230    pub fn new(chunk_size: usize) -> Self {
231        Self { chunk_size }
232    }
233
234    pub fn chunk_array(&self, name: &str, data: &[f64], var_type: VariableType) -> Vec<ArrayData> {
235        let mut chunks = Vec::new();
236        let mut start = 0;
237
238        while start < data.len() {
239            let end = std::cmp::min(start + self.chunk_size, data.len()) - 1;
240            let chunk_data = data[start..=end].to_vec();
241
242            chunks.push(ArrayData::new(
243                name.to_string(),
244                None,
245                start,
246                end,
247                var_type,
248                chunk_data,
249            ));
250
251            start = end + 1;
252        }
253
254        chunks
255    }
256}
257
#[cfg(test)]
mod tests {
    //! Unit tests for the data types and their Protocol Buffer conversions.

    use super::*;
    use ndarray::ArrayD;

    // ----------------------------------------------------------------------------
    // VariableData
    // ----------------------------------------------------------------------------

    #[test]
    fn test_variable_data_new() {
        let var = VariableData::new(
            String::from("x"),
            ArrayD::from_elem(vec![2, 3], 1.0),
            String::from("m"),
            VariableType::KInput,
        );
        assert_eq!(var.name, "x");
        assert_eq!(var.units, "m");
        assert_eq!(var.shape(), &[2, 3]);
        assert_eq!(var.size(), 6);
    }

    #[test]
    fn test_variable_data_zeros() {
        let shape = [3, 4];
        let var = VariableData::zeros(
            String::from("y"),
            &shape,
            String::from("kg"),
            VariableType::KOutput,
        );
        assert_eq!(var.name, "y");
        assert_eq!(var.shape(), &shape);
        assert_eq!(var.size(), 12);
        assert_eq!(var.flatten(), vec![0.0; 12]);
    }

    #[test]
    fn test_variable_data_flatten() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let grid = ArrayD::from_shape_vec(vec![2, 2], values.clone()).unwrap();
        let var = VariableData::new("x".into(), grid, String::new(), VariableType::KInput);
        assert_eq!(var.flatten(), values);
    }

    #[test]
    fn test_variable_data_from_flat() {
        let flat_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let var = VariableData::from_flat(
            "z".into(),
            &flat_data,
            &[2, 3],
            String::new(),
            VariableType::KOutput,
        )
        .expect("six values fit a 2x3 shape");
        assert_eq!(var.shape(), &[2, 3]);
        assert_eq!(var.flatten(), flat_data.to_vec());
    }

    #[test]
    fn test_variable_data_from_flat_shape_mismatch() {
        let outcome = VariableData::from_flat(
            "z".into(),
            &[1.0, 2.0, 3.0],
            &[2, 3], // Expects 6 elements, not 3
            String::new(),
            VariableType::KOutput,
        );
        assert!(matches!(
            outcome,
            Err(PhiloteError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_variable_data_view_mut() {
        let mut var = VariableData::new(
            "x".into(),
            ArrayD::from_elem(vec![2, 2], 0.0),
            String::new(),
            VariableType::KInput,
        );
        {
            let mut cells = var.view_mut();
            cells[[0, 0]] = 5.0;
        }
        assert_eq!(var.data[[0, 0]], 5.0);
    }

    // ----------------------------------------------------------------------------
    // ArrayData
    // ----------------------------------------------------------------------------

    #[test]
    fn test_array_data_new() {
        let values = vec![1.0, 2.0, 3.0];
        let arr = ArrayData::new("x".into(), None, 0, 2, VariableType::KInput, values.clone());
        assert_eq!(arr.name, "x");
        assert!(arr.subname.is_none());
        assert_eq!((arr.start, arr.end), (0, 2));
        assert_eq!(arr.size(), 3);
        assert_eq!(arr.data, values);
    }

    #[test]
    fn test_array_data_with_subname() {
        let arr = ArrayData::new(
            "f".into(),
            Some("x".into()),
            0,
            4,
            VariableType::KPartial,
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        );
        assert_eq!(arr.subname.as_deref(), Some("x"));
    }

    #[test]
    fn test_array_data_to_proto() {
        let proto = Array::from(ArrayData::new(
            "x".into(),
            Some("sub".into()),
            0,
            2,
            VariableType::KInput,
            vec![1.0, 2.0, 3.0],
        ));
        assert_eq!(proto.name, "x");
        assert_eq!(proto.subname, "sub");
        assert_eq!((proto.start, proto.end), (0, 2));
        assert_eq!(proto.data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_array_data_from_proto() {
        let arr = ArrayData::try_from(Array {
            name: "y".into(),
            subname: String::new(),
            start: 5,
            end: 9,
            r#type: VariableType::KOutput.into(),
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        })
        .unwrap();
        assert_eq!(arr.name, "y");
        assert!(arr.subname.is_none());
        assert_eq!((arr.start, arr.end), (5, 9));
        assert_eq!(arr.data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_array_data_from_proto_empty_data() {
        let empty = Array {
            name: "x".into(),
            subname: String::new(),
            start: 0,
            end: 0,
            r#type: VariableType::KInput.into(),
            data: Vec::new(),
        };
        assert!(matches!(
            ArrayData::try_from(empty),
            Err(PhiloteError::ArrayError(_))
        ));
    }

    #[test]
    fn test_array_data_from_proto_invalid_type() {
        let unknown_type = Array {
            name: "x".into(),
            subname: String::new(),
            start: 0,
            end: 2,
            r#type: 999, // Invalid type
            data: vec![1.0, 2.0, 3.0],
        };
        assert!(matches!(
            ArrayData::try_from(unknown_type),
            Err(PhiloteError::InvalidVariableType(_))
        ));
    }

    // ----------------------------------------------------------------------------
    // StreamOptions
    // ----------------------------------------------------------------------------

    #[test]
    fn test_stream_options_default() {
        assert_eq!(StreamOptions::default().max_double_per_slice, 1000);
    }

    #[test]
    fn test_stream_options_to_proto() {
        let proto = crate::philote_info::StreamOptions::from(StreamOptions {
            max_double_per_slice: 500,
        });
        assert_eq!(proto.num_double, 500);
    }

    #[test]
    fn test_stream_options_from_proto() {
        let opts = StreamOptions::from(crate::philote_info::StreamOptions { num_double: 750 });
        assert_eq!(opts.max_double_per_slice, 750);
    }

    #[test]
    fn test_stream_options_copy() {
        let original = StreamOptions::default();
        let copied = original; // Should copy, not move
        assert_eq!(copied.max_double_per_slice, original.max_double_per_slice);
    }

    // ----------------------------------------------------------------------------
    // PartialsInfo
    // ----------------------------------------------------------------------------

    #[test]
    fn test_partials_info_new() {
        let info = PartialsInfo::new("f".into(), "x".into(), vec![3, 4]);
        assert_eq!((info.name.as_str(), info.subname.as_str()), ("f", "x"));
        assert_eq!(info.shape, vec![3, 4]);
        assert_eq!(info.size(), 12);
    }

    #[test]
    fn test_partials_info_size_scalar() {
        assert_eq!(PartialsInfo::new("f".into(), "x".into(), vec![1]).size(), 1);
    }

    #[test]
    fn test_partials_info_to_proto() {
        let proto = PartialsMetaData::from(PartialsInfo::new("df".into(), "dx".into(), vec![2, 3]));
        assert_eq!(proto.name, "df");
        assert_eq!(proto.subname, "dx");
        assert_eq!(proto.shape, vec![2, 3]);
    }

    // ----------------------------------------------------------------------------
    // ArrayChunker
    // ----------------------------------------------------------------------------

    #[test]
    fn test_array_chunker_exact_chunks() {
        let chunks = ArrayChunker::new(3).chunk_array(
            "x",
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            VariableType::KInput,
        );
        assert_eq!(chunks.len(), 2);
        let ranges: Vec<(usize, usize)> = chunks.iter().map(|c| (c.start, c.end)).collect();
        assert_eq!(ranges, vec![(0, 2), (3, 5)]);
        assert_eq!(chunks[0].data, vec![1.0, 2.0, 3.0]);
        assert_eq!(chunks[1].data, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_array_chunker_partial_last_chunk() {
        let chunks = ArrayChunker::new(4).chunk_array(
            "y",
            &[1.0, 2.0, 3.0, 4.0, 5.0],
            VariableType::KOutput,
        );
        assert_eq!(chunks.len(), 2);
        let (first, last) = (&chunks[0], &chunks[1]);
        assert_eq!(first.data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(last.data, vec![5.0]);
        assert_eq!((last.start, last.end), (4, 4));
    }

    #[test]
    fn test_array_chunker_single_chunk() {
        let values = vec![1.0, 2.0, 3.0];
        let chunks = ArrayChunker::new(10).chunk_array("z", &values, VariableType::KInput);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].data, values);
    }

    #[test]
    fn test_array_chunker_single_element() {
        let chunks = ArrayChunker::new(5).chunk_array("a", &[42.0], VariableType::KInput);
        assert_eq!(chunks.len(), 1);
        let only = &chunks[0];
        assert_eq!((only.start, only.end), (0, 0));
        assert_eq!(only.data, vec![42.0]);
    }

    #[test]
    fn test_array_chunker_preserves_name_and_type() {
        let chunks =
            ArrayChunker::new(2).chunk_array("test", &[1.0, 2.0, 3.0], VariableType::KPartial);
        for chunk in &chunks {
            assert_eq!(chunk.name, "test");
            assert_eq!(chunk.var_type, VariableType::KPartial);
        }
    }
}