Skip to main content

shape_ext_python/
marshaling.rs

//! WireValue <-> Python object conversion.
2
/// Map a Shape type name to the equivalent Python type-hint string.
///
/// Scalars map directly: `number` -> `float`, `int` -> `int`, `bool` -> `bool`,
/// `string` -> `str`, `none` -> `None`. `Array<T>` maps recursively to
/// `list[...]`. Anything else — including a malformed array type such as
/// `Array<int` (missing `>`) — maps to `object`.
pub fn shape_type_to_python_hint(shape_type: &str) -> String {
    match shape_type {
        "number" => "float".to_string(),
        "int" => "int".to_string(),
        "bool" => "bool".to_string(),
        "string" => "str".to_string(),
        "none" => "None".to_string(),
        other => {
            // Require both the `Array<` prefix and the closing `>`; slicing by
            // fixed offsets after only a prefix check panics on `"Array<"`.
            match other
                .strip_prefix("Array<")
                .and_then(|rest| rest.strip_suffix('>'))
            {
                Some(inner) => format!("list[{}]", shape_type_to_python_hint(inner)),
                None => "object".to_string(),
            }
        }
    }
}
18
/// Convert an [`rmpv::Value`] into a Python object.
///
/// Mapping:
/// - `Nil` -> `None`
/// - `Boolean` -> `bool`
/// - `Integer` -> `int`
/// - `F32`/`F64` -> `float`
/// - `String` -> `str`
/// - `Binary` -> `bytes`
/// - `Array` -> `list`
/// - `Map` -> `dict` (keys and values converted recursively)
///
/// Strings that are not valid UTF-8, and `Ext` values, have no direct Python
/// equivalent and become `None`.
///
/// # Errors
///
/// Returns an error string if a Python `list` or `dict` cannot be built, e.g.
/// when a map key converts to an unhashable object such as a `list`.
#[cfg(feature = "pyo3")]
pub fn msgpack_to_pyobject(
    py: pyo3::Python<'_>,
    value: &rmpv::Value,
) -> Result<pyo3::Py<pyo3::PyAny>, String> {
    use pyo3::IntoPyObject;

    // `into_pyobject` cannot fail for the scalar types converted below, so
    // these `unwrap`s are never hit.
    match value {
        rmpv::Value::Nil => Ok(py.None()),
        rmpv::Value::Boolean(b) => Ok(b.into_pyobject(py).unwrap().to_owned().unbind().into()),
        rmpv::Value::Integer(i) => {
            if let Some(n) = i.as_i64() {
                Ok(n.into_pyobject(py).unwrap().unbind().into())
            } else if let Some(n) = i.as_u64() {
                Ok(n.into_pyobject(py).unwrap().unbind().into())
            } else {
                // Every msgpack integer fits in i64 or u64; this is only a
                // defensive fallback.
                Ok(py.None())
            }
        }
        rmpv::Value::F32(f) => Ok(f64::from(*f).into_pyobject(py).unwrap().unbind().into()),
        rmpv::Value::F64(f) => Ok(f.into_pyobject(py).unwrap().unbind().into()),
        rmpv::Value::String(s) => match s.as_str() {
            Some(text) => Ok(text.into_pyobject(py).unwrap().unbind().into()),
            // Not valid UTF-8: there is no faithful `str` for it.
            None => Ok(py.None()),
        },
        // Preserve raw bytes instead of silently dropping them to `None`.
        rmpv::Value::Binary(bytes) => {
            Ok(pyo3::types::PyBytes::new(py, bytes).into_any().unbind())
        }
        rmpv::Value::Array(arr) => {
            let items: Vec<pyo3::Py<pyo3::PyAny>> = arr
                .iter()
                .map(|v| msgpack_to_pyobject(py, v))
                .collect::<Result<_, _>>()?;
            let list = pyo3::types::PyList::new(py, &items)
                .map_err(|e| format!("Failed to create Python list: {}", e))?;
            Ok(list.unbind().into())
        }
        rmpv::Value::Map(entries) => {
            use pyo3::types::PyDictMethods;
            let dict = pyo3::types::PyDict::new(py);
            for (k, v) in entries {
                let py_key = msgpack_to_pyobject(py, k)?;
                let py_val = msgpack_to_pyobject(py, v)?;
                dict.set_item(py_key, py_val)
                    .map_err(|e| format!("Failed to set dict item: {}", e))?;
            }
            Ok(dict.unbind().into())
        }
        rmpv::Value::Ext(_, _) => Ok(py.None()),
    }
}
73
/// Convert a Python object to an [`rmpv::Value`] without a declared target
/// type (untyped path).
///
/// Mapping:
/// - `bool` -> `Boolean`
/// - `int` -> `Integer` (signed when it fits in `i64`, otherwise unsigned
///   when it fits in `u64`)
/// - `float` -> `F64`
/// - `str` -> `String`
/// - `None` -> `Nil`
/// - `bytes` -> `Binary`
/// - `list`/`tuple` -> `Array`
/// - `dict` -> `Map` (keys and values converted recursively)
///
/// Any other object is converted best-effort to `Nil`.
///
/// # Errors
///
/// Returns an error string if an `int` fits in neither `i64` nor `u64`, or if
/// extracting or downcasting a recognised type fails.
#[cfg(feature = "pyo3")]
pub fn pyobject_to_msgpack(
    py: pyo3::Python<'_>,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
) -> Result<rmpv::Value, String> {
    use pyo3::types::*;

    // Check bool BEFORE int (bool is subclass of int in Python)
    if obj.is_instance_of::<PyBool>() {
        let b: bool = obj
            .extract()
            .map_err(|e| format!("Failed to extract bool: {}", e))?;
        return Ok(rmpv::Value::Boolean(b));
    }

    if obj.is_instance_of::<PyInt>() {
        // Try the signed range first, then fall back to u64. Values in
        // (i64::MAX, u64::MAX] — which `msgpack_to_pyobject` produces from
        // unsigned msgpack integers — then still round-trip.
        if let Ok(i) = obj.extract::<i64>() {
            return Ok(rmpv::Value::Integer(rmpv::Integer::from(i)));
        }
        let u: u64 = obj
            .extract()
            .map_err(|e| format!("Failed to extract int: {}", e))?;
        return Ok(rmpv::Value::Integer(rmpv::Integer::from(u)));
    }

    if obj.is_instance_of::<PyFloat>() {
        let f: f64 = obj
            .extract()
            .map_err(|e| format!("Failed to extract float: {}", e))?;
        return Ok(rmpv::Value::F64(f));
    }

    if obj.is_instance_of::<PyString>() {
        let s: String = obj
            .extract()
            .map_err(|e| format!("Failed to extract string: {}", e))?;
        return Ok(rmpv::Value::String(rmpv::Utf8String::from(s)));
    }

    if obj.is_none() {
        return Ok(rmpv::Value::Nil);
    }

    if obj.is_instance_of::<PyBytes>() {
        let bytes = obj
            .cast::<PyBytes>()
            .map_err(|e| format!("Failed to downcast bytes: {}", e))?;
        return Ok(rmpv::Value::Binary(bytes.as_bytes().to_vec()));
    }

    if obj.is_instance_of::<PyList>() {
        let list = obj
            .cast::<PyList>()
            .map_err(|e| format!("Failed to downcast list: {}", e))?;
        let items: Vec<rmpv::Value> = list
            .iter()
            .map(|item| pyobject_to_msgpack(py, &item))
            .collect::<Result<_, _>>()?;
        return Ok(rmpv::Value::Array(items));
    }

    if obj.is_instance_of::<PyTuple>() {
        let tuple = obj
            .cast::<PyTuple>()
            .map_err(|e| format!("Failed to downcast tuple: {}", e))?;
        let items: Vec<rmpv::Value> = tuple
            .iter()
            .map(|item| pyobject_to_msgpack(py, &item))
            .collect::<Result<_, _>>()?;
        return Ok(rmpv::Value::Array(items));
    }

    if obj.is_instance_of::<PyDict>() {
        let dict = obj
            .cast::<PyDict>()
            .map_err(|e| format!("Failed to downcast dict: {}", e))?;
        let entries: Vec<(rmpv::Value, rmpv::Value)> = dict
            .iter()
            .map(|(k, v)| {
                let mk = pyobject_to_msgpack(py, &k)?;
                let mv = pyobject_to_msgpack(py, &v)?;
                Ok((mk, mv))
            })
            .collect::<Result<_, String>>()?;
        return Ok(rmpv::Value::Map(entries));
    }

    // Fallback: unsupported Python types are converted best-effort to Nil.
    Ok(rmpv::Value::Nil)
}
144
// ============================================================================
// Type-aware marshalling
// ============================================================================
148
/// Unwrap a `Result<T>` type string to its inner type `T`.
///
/// Strings not of the form `Result<...>` are returned unchanged.
pub fn strip_result_wrapper(s: &str) -> &str {
    s.strip_prefix("Result<")
        .and_then(|rest| rest.strip_suffix('>'))
        .unwrap_or(s)
}
157
/// Return the element type `T` of an `Array<T>` type string.
///
/// Returns `None` for any string that is not of the form `Array<...>`.
fn strip_array_wrapper(s: &str) -> Option<&str> {
    let after_prefix = s.strip_prefix("Array<")?;
    after_prefix.strip_suffix('>')
}
166
167/// Parse `{f1: T1, f2: T2}` to a Vec of (field_name, field_type) pairs.
168fn parse_object_fields(s: &str) -> Vec<(&str, &str)> {
169    let s = s.trim();
170    if !s.starts_with('{') || !s.ends_with('}') {
171        return Vec::new();
172    }
173    let inner = s[1..s.len() - 1].trim();
174    if inner.is_empty() {
175        return Vec::new();
176    }
177
178    let mut fields = Vec::new();
179    let mut depth = 0i32;
180    let mut start = 0;
181
182    // Split on commas, respecting nested angle brackets and braces
183    for (i, ch) in inner.char_indices() {
184        match ch {
185            '<' | '{' => depth += 1,
186            '>' | '}' => depth -= 1,
187            ',' if depth == 0 => {
188                if let Some(pair) = parse_single_field(inner[start..i].trim()) {
189                    fields.push(pair);
190                }
191                start = i + 1;
192            }
193            _ => {}
194        }
195    }
196    // Last field
197    if let Some(pair) = parse_single_field(inner[start..].trim()) {
198        fields.push(pair);
199    }
200
201    fields
202}
203
/// Parse one `name: type` (or `name?: type`) field spec into `(name, type)`.
///
/// The split happens at the first `:`. Surrounding whitespace is trimmed, and
/// any trailing `?` characters are removed from the name. Returns `None` when
/// there is no `:` or when either side is empty.
///
/// NOTE(review): the `?` optional marker is discarded rather than reported to
/// the caller, so optional and required fields look identical downstream —
/// confirm this is intended.
fn parse_single_field(s: &str) -> Option<(&str, &str)> {
    let (raw_name, raw_type) = s.trim().split_once(':')?;
    let field_name = raw_name.trim().trim_end_matches('?');
    let field_type = raw_type.trim();
    (!field_name.is_empty() && !field_type.is_empty()).then_some((field_name, field_type))
}
215
/// Convert a Python object to an `rmpv::Value`, validating and coercing it
/// against the declared Shape return type `target_type`.
///
/// This is the typed marshalling path. A `Result<T>` declaration is checked
/// against its inner type `T`. Coercions follow the typed rules (for example,
/// the float `3.0` is accepted for `int` and becomes `3`).
///
/// # Errors
///
/// Returns a message describing the first mismatch between the value and
/// the declared type.
#[cfg(feature = "pyo3")]
pub fn pyobject_to_typed_msgpack(
    py: pyo3::Python<'_>,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
    target_type: &str,
) -> Result<rmpv::Value, String> {
    convert_with_type(py, obj, strip_result_wrapper(target_type))
}
230
/// Recursively convert `obj` into an [`rmpv::Value`], validating it against
/// the Shape type string `target` and applying the allowed coercions.
///
/// Accepted inputs per target:
/// - `"none"`: only Python `None`.
/// - `"int"`: a Python `int` (never `bool`) that fits in `i64`, or in `u64`
///   when non-negative. Also a `float` with an integral value inside the
///   `i64` range (e.g. `3.0` -> `3`).
/// - `"float"` / `"number"`: a Python `float`, or an `int` (never `bool`),
///   emitted as `F64`.
/// - `"string"` / `"bool"`: exactly `str` / `bool`.
/// - `Array<T>`: a `list` or `tuple` whose every element converts as `T`.
/// - `{f1: T1, ...}`: a `dict` containing every declared field. Only the
///   declared fields are emitted, in declaration order.
/// - anything else (e.g. `"object"`): untyped conversion via
///   [`pyobject_to_msgpack`].
///
/// `None` is rejected for every target other than `"none"`, including
/// `"object"` and unknown targets.
///
/// # Errors
///
/// Returns a message describing the first mismatch. Failures inside arrays
/// and objects are prefixed with the element index or field name.
#[cfg(feature = "pyo3")]
fn convert_with_type(
    py: pyo3::Python<'_>,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
    target: &str,
) -> Result<rmpv::Value, String> {
    use pyo3::types::*;

    // Handle None/nil
    if obj.is_none() {
        return match target {
            "none" => Ok(rmpv::Value::Nil),
            _ => Err(format!("expected {}, got None", target)),
        };
    }

    // Checked before the `match` so the element type is bound once rather
    // than re-parsed and unwrapped inside a guard. None of the scalar target
    // names below start with `Array<`, so this shadows no arm.
    if let Some(elem_type) = strip_array_wrapper(target) {
        return convert_array_with_type(py, obj, elem_type);
    }

    match target {
        "int" => {
            // Bool is NOT coerced to int — reject it explicitly
            if obj.is_instance_of::<PyBool>() {
                return Err("expected int, got bool (bool is not coerced to int)".to_string());
            }
            if obj.is_instance_of::<PyInt>() {
                // Prefer the signed range; fall back to u64 so large
                // non-negative ints are accepted rather than rejected.
                if let Ok(i) = obj.extract::<i64>() {
                    return Ok(rmpv::Value::Integer(rmpv::Integer::from(i)));
                }
                let u: u64 = obj
                    .extract()
                    .map_err(|e| format!("Failed to extract int: {}", e))?;
                return Ok(rmpv::Value::Integer(rmpv::Integer::from(u)));
            }
            // Coerce float with integer value (3.0) -> int
            if obj.is_instance_of::<PyFloat>() {
                let f: f64 = obj
                    .extract()
                    .map_err(|e| format!("Failed to extract float: {}", e))?;
                // NaN and ±inf have a NaN fractional part, so they fail here too.
                if f.fract() != 0.0 {
                    return Err(format!(
                        "expected int, got float {} (not an integer value)",
                        f
                    ));
                }
                // `i64::MIN as f64` is exactly -2^63, so `limit` is exactly 2^63.
                // The upper bound must be exclusive: `i64::MAX as f64` rounds
                // up to 2^63, which does not fit in i64, and `as i64` would
                // silently saturate it to i64::MAX.
                let limit = -(i64::MIN as f64);
                if !(-limit..limit).contains(&f) {
                    return Err(format!(
                        "expected int, got float {} (out of int range)",
                        f
                    ));
                }
                return Ok(rmpv::Value::Integer(rmpv::Integer::from(f as i64)));
            }
            Err(format!("expected int, got {}", py_type_name(obj)))
        }

        "float" | "number" => {
            if obj.is_instance_of::<PyBool>() {
                return Err(format!("expected {}, got bool", target));
            }
            if obj.is_instance_of::<PyFloat>() {
                let f: f64 = obj
                    .extract()
                    .map_err(|e| format!("Failed to extract float: {}", e))?;
                return Ok(rmpv::Value::F64(f));
            }
            // Coerce int -> float by extracting `f64` directly. This lets
            // Python perform the conversion, so ints outside the i64 range
            // are accepted; only ints too large for an f64 fail.
            if obj.is_instance_of::<PyInt>() {
                let f: f64 = obj
                    .extract()
                    .map_err(|e| format!("Failed to extract int: {}", e))?;
                return Ok(rmpv::Value::F64(f));
            }
            Err(format!("expected {}, got {}", target, py_type_name(obj)))
        }

        "string" => {
            if obj.is_instance_of::<PyString>() {
                let s: String = obj
                    .extract()
                    .map_err(|e| format!("Failed to extract string: {}", e))?;
                return Ok(rmpv::Value::String(rmpv::Utf8String::from(s)));
            }
            Err(format!("expected string, got {}", py_type_name(obj)))
        }

        "bool" => {
            if obj.is_instance_of::<PyBool>() {
                let b: bool = obj
                    .extract()
                    .map_err(|e| format!("Failed to extract bool: {}", e))?;
                return Ok(rmpv::Value::Boolean(b));
            }
            Err(format!("expected bool, got {}", py_type_name(obj)))
        }

        "none" => Err(format!("expected none, got {}", py_type_name(obj))),

        // Object type: {f1: T1, f2: T2, ...}
        s if s.starts_with('{') && s.ends_with('}') => convert_object_with_type(py, obj, s),

        // "object" or unknown — fall back to untyped
        _ => pyobject_to_msgpack(py, obj),
    }
}

/// Convert a `list` or `tuple` into an `Array`, validating every element
/// against `elem_type`.
#[cfg(feature = "pyo3")]
fn convert_array_with_type(
    py: pyo3::Python<'_>,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
    elem_type: &str,
) -> Result<rmpv::Value, String> {
    use pyo3::types::*;

    let elements: Vec<pyo3::Bound<'_, pyo3::PyAny>> = if let Ok(list) = obj.cast::<PyList>() {
        list.iter().collect()
    } else if let Ok(tuple) = obj.cast::<PyTuple>() {
        tuple.iter().collect()
    } else {
        return Err(format!("expected Array, got {}", py_type_name(obj)));
    };

    let items = elements
        .iter()
        .enumerate()
        .map(|(i, item)| {
            convert_with_type(py, item, elem_type)
                .map_err(|e| format!("Array element [{}]: {}", i, e))
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(rmpv::Value::Array(items))
}

/// Convert a `dict` into a `Map` that holds exactly the fields declared by
/// the object type literal `object_type` (e.g. `{a: int, b: string}`).
///
/// Keys of the dict that are not declared are ignored. If no fields can be
/// parsed from `object_type`, the dict is converted untyped.
///
/// NOTE(review): `parse_single_field` drops the `?` optional marker, so
/// fields declared as `name?: T` are currently required here too — confirm
/// that this is intended.
#[cfg(feature = "pyo3")]
fn convert_object_with_type(
    py: pyo3::Python<'_>,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
    object_type: &str,
) -> Result<rmpv::Value, String> {
    use pyo3::types::*;

    if !obj.is_instance_of::<PyDict>() {
        return Err(format!("expected object, got {}", py_type_name(obj)));
    }
    let dict = obj
        .cast::<PyDict>()
        .map_err(|e| format!("Failed to downcast dict: {}", e))?;

    let declared_fields = parse_object_fields(object_type);
    if declared_fields.is_empty() {
        // Empty object declaration or parse failure — fall back to untyped
        return pyobject_to_msgpack(py, obj);
    }

    let mut entries = Vec::with_capacity(declared_fields.len());
    for (field_name, field_type) in declared_fields {
        let key_obj = PyString::new(py, field_name);
        let value_obj = dict
            .get_item(&key_obj)
            .map_err(|e| format!("Failed to get field '{}': {}", field_name, e))?
            .ok_or_else(|| {
                format!("missing required field '{}' in returned dict", field_name)
            })?;
        let typed_val = convert_with_type(py, &value_obj, field_type)
            .map_err(|e| format!("field '{}': {}", field_name, e))?;
        entries.push((
            rmpv::Value::String(rmpv::Utf8String::from(field_name.to_string())),
            typed_val,
        ));
    }
    Ok(rmpv::Value::Map(entries))
}
375
/// Return a short, human-readable name for the Python type of `obj`, for use
/// in type-mismatch error messages.
///
/// `bool` is tested before `int` because `bool` subclasses `int` in Python.
/// Types not recognised here are reported as `"object"`.
#[cfg(feature = "pyo3")]
fn py_type_name(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> &'static str {
    use pyo3::types::*;
    if obj.is_none() {
        "None"
    } else if obj.is_instance_of::<PyBool>() {
        "bool"
    } else if obj.is_instance_of::<PyInt>() {
        "int"
    } else if obj.is_instance_of::<PyFloat>() {
        "float"
    } else if obj.is_instance_of::<PyString>() {
        "string"
    } else if obj.is_instance_of::<PyBytes>() {
        "bytes"
    } else if obj.is_instance_of::<PyList>() {
        "list"
    } else if obj.is_instance_of::<PyTuple>() {
        "tuple"
    } else if obj.is_instance_of::<PyDict>() {
        "dict"
    } else {
        "object"
    }
}