Skip to main content

cognee_llm/
dynamic_model.rs

//! Dynamic graph model for cross-SDK schema sharing.
//!
//! Provides [`DynamicGraphModel`], a runtime representation of a graph model's
//! JSON schema. This enables schema exchange between the Python and Rust cognee
//! SDKs without requiring compiled types on both sides.
//!
//! The Python SDK can serialize a Pydantic model to JSON schema, send it to the
//! Rust SDK as a [`DynamicGraphModel`], and vice versa. This mirrors Python's
//! `graph_model_to_graph_schema()` / `graph_schema_to_graph_model()` from
//! `cognee/shared/graph_model_utils.py`. Note that the Rust
//! [`graph_schema_to_graph_model`] only validates the schema's shape; unlike the
//! Python version, it does not build a runtime model class.
//!
//! # Usage
//!
//! ```
//! use cognee_llm::DynamicGraphModel;
//! use schemars::JsonSchema;
//! use serde::{Deserialize, Serialize};
//!
//! // From a Rust type
//! #[derive(Serialize, Deserialize, JsonSchema, Clone)]
//! struct MyModel {
//!     entities: Vec<String>,
//! }
//!
//! let model = DynamicGraphModel::from_type::<MyModel>("MyModel");
//! assert_eq!(model.name, "MyModel");
//!
//! // From a pre-existing JSON schema (e.g., received from Python)
//! let schema = serde_json::json!({
//!     "type": "object",
//!     "properties": {
//!         "name": { "type": "string" }
//!     },
//!     "required": ["name"]
//! });
//! let model = DynamicGraphModel::from_schema("ExternalModel", schema);
//! ```
38
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::Value;
42use thiserror::Error;
43
44use crate::schema::generate_json_schema;
45
46/// A runtime representation of a graph model's JSON schema.
47///
48/// Stores the JSON schema for a graph model so it can be serialized, transmitted
49/// between SDKs, and used for LLM structured output without requiring the
50/// concrete Rust type at runtime.
51///
52/// # Fields
53/// * `name` - Human-readable name for the model (e.g., "KnowledgeGraph", "ProgrammingLanguage")
54/// * `schema` - The JSON schema as a `serde_json::Value`
55/// * `description` - Optional description of what the model represents
56/// * `source` - Optional source identifier (e.g., "python-sdk", "rust-sdk")
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct DynamicGraphModel {
59    /// Human-readable name for the model.
60    pub name: String,
61
62    /// The JSON schema describing the model's structure.
63    pub schema: Value,
64
65    /// Optional description of what the model represents.
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub description: Option<String>,
68
69    /// Optional source identifier (e.g., "python-sdk", "rust-sdk").
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub source: Option<String>,
72}
73
74impl DynamicGraphModel {
75    /// Create a [`DynamicGraphModel`] from a Rust type that implements [`JsonSchema`].
76    ///
77    /// Generates the JSON schema at runtime via `schemars` and stores it alongside
78    /// the given name. The `source` field is automatically set to `"rust-sdk"`.
79    ///
80    /// # Arguments
81    /// * `name` - Human-readable name for the model
82    ///
83    /// # Example
84    /// ```
85    /// use cognee_llm::DynamicGraphModel;
86    /// use schemars::JsonSchema;
87    /// use serde::{Deserialize, Serialize};
88    ///
89    /// #[derive(Serialize, Deserialize, JsonSchema, Clone)]
90    /// struct PersonGraph {
91    ///     people: Vec<String>,
92    ///     relationships: Vec<(String, String)>,
93    /// }
94    ///
95    /// let model = DynamicGraphModel::from_type::<PersonGraph>("PersonGraph");
96    /// assert_eq!(model.name, "PersonGraph");
97    /// assert_eq!(model.source.as_deref(), Some("rust-sdk"));
98    /// ```
99    pub fn from_type<T: JsonSchema>(name: impl Into<String>) -> Self {
100        Self {
101            name: name.into(),
102            schema: generate_json_schema::<T>(),
103            description: None,
104            source: Some("rust-sdk".to_string()),
105        }
106    }
107
108    /// Create a [`DynamicGraphModel`] from a pre-existing JSON schema.
109    ///
110    /// Use this when receiving a schema from an external source (e.g., the Python SDK
111    /// serialized a Pydantic model to JSON schema).
112    ///
113    /// # Arguments
114    /// * `name` - Human-readable name for the model
115    /// * `schema` - The JSON schema as a `serde_json::Value`
116    ///
117    /// # Example
118    /// ```
119    /// use cognee_llm::DynamicGraphModel;
120    ///
121    /// let schema = serde_json::json!({
122    ///     "type": "object",
123    ///     "properties": {
124    ///         "name": { "type": "string" }
125    ///     },
126    ///     "required": ["name"]
127    /// });
128    /// let model = DynamicGraphModel::from_schema("ExternalModel", schema);
129    /// assert_eq!(model.name, "ExternalModel");
130    /// assert!(model.source.is_none());
131    /// ```
132    pub fn from_schema(name: impl Into<String>, schema: Value) -> Self {
133        Self {
134            name: name.into(),
135            schema,
136            description: None,
137            source: None,
138        }
139    }
140
141    /// Set an optional description on this model.
142    pub fn with_description(mut self, description: impl Into<String>) -> Self {
143        self.description = Some(description.into());
144        self
145    }
146
147    /// Set an optional source identifier on this model.
148    pub fn with_source(mut self, source: impl Into<String>) -> Self {
149        self.source = Some(source.into());
150        self
151    }
152
153    /// Check whether the schema has a `"properties"` key (i.e., looks like an object schema).
154    ///
155    /// This is a lightweight structural check, not full JSON Schema validation.
156    /// For full structural validation, deserialize with `serde_json::from_value::<T>()`
157    /// which enforces all type constraints.
158    pub fn has_properties(&self) -> bool {
159        self.schema.get("properties").is_some()
160    }
161
162    /// Get the list of required field names from the schema, if any.
163    ///
164    /// Returns `None` if the schema has no `"required"` key. Returns `Some(vec)`
165    /// with the field names otherwise.
166    pub fn required_fields(&self) -> Option<Vec<&str>> {
167        self.schema.get("required").and_then(|v| {
168            v.as_array().map(|arr| {
169                arr.iter()
170                    .filter_map(|item| item.as_str())
171                    .collect::<Vec<_>>()
172            })
173        })
174    }
175
176    /// Check whether a JSON value has all the required fields defined in this schema.
177    ///
178    /// This performs a lightweight check: it only verifies that required fields
179    /// exist as keys in the JSON object. It does **not** validate types or nested
180    /// structures. For full structural validation, use `serde_json::from_value::<T>()`.
181    ///
182    /// Returns `Ok(())` if all required fields are present (or if there are no
183    /// required fields). Returns `Err` with a message listing missing fields.
184    pub fn check_required_fields(&self, instance: &Value) -> Result<(), String> {
185        let required = match self.required_fields() {
186            Some(fields) => fields,
187            None => return Ok(()),
188        };
189
190        let obj = instance.as_object().ok_or_else(|| {
191            format!(
192                "Expected a JSON object for model '{}', got {}",
193                self.name,
194                value_type_name(instance)
195            )
196        })?;
197
198        let missing: Vec<&str> = required
199            .iter()
200            .filter(|field| !obj.contains_key(**field))
201            .copied()
202            .collect();
203
204        if missing.is_empty() {
205            Ok(())
206        } else {
207            Err(format!(
208                "Model '{}' is missing required fields: {}",
209                self.name,
210                missing.join(", ")
211            ))
212        }
213    }
214}
215
216// ─── graph_schema_to_graph_model ──────────────────────────────────────────────
217
218/// Errors emitted by [`graph_schema_to_graph_model`].
219#[derive(Debug, Error, PartialEq, Eq)]
220pub enum GraphModelError {
221    #[error("graph schema must be a JSON object, got {0}")]
222    NotAnObject(&'static str),
223
224    #[error("graph schema is missing required key `{0}`")]
225    MissingKey(&'static str),
226
227    #[error("graph schema field `{0}` must be a list, got {1}")]
228    NotAList(&'static str, &'static str),
229}
230
231/// Validate a JSON value against the canonical graph-model shape.
232///
233/// Mirrors Python's
234/// [`graph_schema_to_graph_model`](https://github.com/topoteretes/cognee/blob/main/cognee/shared/graph_model_utils.py)
235/// in *spirit only* — the Rust port does not generate runtime Pydantic classes
236/// (the LLM-router handler only ever uses the error path to distinguish
237/// "schema invalid" → 409 from "JSON parse error" → 422).
238///
239/// The validation rules are:
240/// - Top-level value must be a JSON object.
241/// - The object must carry an `entity_types` array.
242/// - The object must carry a `relationship_types` array.
243///
244/// On success returns `Ok(())` (the success value is unused by the handler).
245pub fn graph_schema_to_graph_model(value: &Value) -> Result<(), GraphModelError> {
246    let obj = match value {
247        Value::Object(map) => map,
248        _ => return Err(GraphModelError::NotAnObject(value_type_name(value))),
249    };
250
251    let entity_types = obj
252        .get("entity_types")
253        .ok_or(GraphModelError::MissingKey("entity_types"))?;
254    if !entity_types.is_array() {
255        return Err(GraphModelError::NotAList(
256            "entity_types",
257            value_type_name(entity_types),
258        ));
259    }
260
261    let relationship_types = obj
262        .get("relationship_types")
263        .ok_or(GraphModelError::MissingKey("relationship_types"))?;
264    if !relationship_types.is_array() {
265        return Err(GraphModelError::NotAList(
266            "relationship_types",
267            value_type_name(relationship_types),
268        ));
269    }
270
271    Ok(())
272}
273
274/// Return a human-readable name for a JSON value type.
275fn value_type_name(v: &Value) -> &'static str {
276    match v {
277        Value::Null => "null",
278        Value::Bool(_) => "boolean",
279        Value::Number(_) => "number",
280        Value::String(_) => "string",
281        Value::Array(_) => "array",
282        Value::Object(_) => "object",
283    }
284}
285
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    /// A KnowledgeGraph-like model for testing.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct TestNode {
        id: String,
        name: String,
        #[serde(rename = "type")]
        node_type: String,
        description: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct TestEdge {
        source_node_id: String,
        target_node_id: String,
        relationship_name: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct TestKnowledgeGraph {
        #[serde(default)]
        nodes: Vec<TestNode>,
        #[serde(default)]
        edges: Vec<TestEdge>,
    }

    /// Schema shared by the `check_required_fields` tests: two required fields.
    fn item_schema() -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "value": { "type": "number" }
            },
            "required": ["name", "value"]
        })
    }

    // ─── DynamicGraphModel construction ──────────────────────────────────────

    #[test]
    fn test_from_type_produces_valid_schema() {
        let model = DynamicGraphModel::from_type::<TestKnowledgeGraph>("KnowledgeGraph");

        assert_eq!(model.name, "KnowledgeGraph");
        assert_eq!(model.source.as_deref(), Some("rust-sdk"));
        assert!(model.description.is_none());

        // Schema should be an object with standard JSON Schema keys
        assert!(model.schema.is_object());

        // Should have "properties" containing "nodes" and "edges"
        let props = &model.schema["properties"];
        assert!(props.is_object(), "schema should have 'properties'");
        assert!(
            props.get("nodes").is_some(),
            "schema should have 'nodes' property"
        );
        assert!(
            props.get("edges").is_some(),
            "schema should have 'edges' property"
        );
    }

    #[test]
    fn test_from_type_has_type_object() {
        let model = DynamicGraphModel::from_type::<TestKnowledgeGraph>("KnowledgeGraph");
        assert_eq!(model.schema["type"], "object");
    }

    #[test]
    fn test_from_schema_with_arbitrary_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "language": { "type": "string" },
                "version": { "type": "number" }
            },
            "required": ["language"]
        });

        let model = DynamicGraphModel::from_schema("ProgrammingLanguage", schema.clone());

        assert_eq!(model.name, "ProgrammingLanguage");
        assert!(model.source.is_none());
        assert!(model.description.is_none());
        assert_eq!(model.schema, schema);
    }

    #[test]
    fn test_builder_methods() {
        let model = DynamicGraphModel::from_schema("Test", serde_json::json!({}))
            .with_description("A test model")
            .with_source("python-sdk");

        assert_eq!(model.description.as_deref(), Some("A test model"));
        assert_eq!(model.source.as_deref(), Some("python-sdk"));
    }

    #[test]
    fn test_with_source_overrides_rust_sdk_default() {
        let model =
            DynamicGraphModel::from_type::<TestKnowledgeGraph>("KnowledgeGraph").with_source("x");
        assert_eq!(model.source.as_deref(), Some("x"));
    }

    // ─── Serialization ───────────────────────────────────────────────────────

    #[test]
    fn test_round_trip_serialization() {
        let original = DynamicGraphModel::from_type::<TestKnowledgeGraph>("KnowledgeGraph")
            .with_description("A knowledge graph model")
            .with_source("test-suite");

        // Serialize to JSON string
        let json_str = serde_json::to_string(&original).unwrap();

        // Deserialize back
        let restored: DynamicGraphModel = serde_json::from_str(&json_str).unwrap();

        assert_eq!(restored.name, original.name);
        assert_eq!(restored.schema, original.schema);
        assert_eq!(restored.description, original.description);
        assert_eq!(restored.source, original.source);
    }

    #[test]
    fn test_round_trip_through_value() {
        let original = DynamicGraphModel::from_type::<TestKnowledgeGraph>("KnowledgeGraph");

        // Serialize to Value and back
        let value = serde_json::to_value(&original).unwrap();
        let restored: DynamicGraphModel = serde_json::from_value(value).unwrap();

        assert_eq!(restored.name, original.name);
        assert_eq!(restored.schema, original.schema);
    }

    #[test]
    fn test_skip_serializing_none_fields() {
        let model = DynamicGraphModel::from_schema("Minimal", serde_json::json!({}));
        let json = serde_json::to_value(&model).unwrap();
        let obj = json.as_object().unwrap();

        // description and source should not be present when None
        assert!(!obj.contains_key("description"));
        assert!(!obj.contains_key("source"));

        // name and schema should always be present
        assert!(obj.contains_key("name"));
        assert!(obj.contains_key("schema"));
    }

    #[test]
    fn test_deserialize_without_optional_fields() {
        // Payloads from other SDKs may omit the optional keys entirely.
        let json = serde_json::json!({ "name": "Minimal", "schema": {} });
        let model: DynamicGraphModel = serde_json::from_value(json).unwrap();
        assert_eq!(model.name, "Minimal");
        assert!(model.description.is_none());
        assert!(model.source.is_none());
    }

    // ─── Schema queries ──────────────────────────────────────────────────────

    #[test]
    fn test_has_properties() {
        let model = DynamicGraphModel::from_type::<TestKnowledgeGraph>("KnowledgeGraph");
        assert!(model.has_properties());

        let empty = DynamicGraphModel::from_schema("Empty", serde_json::json!({}));
        assert!(!empty.has_properties());
    }

    #[test]
    fn test_required_fields() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });
        let model = DynamicGraphModel::from_schema("Person", schema);

        let required = model.required_fields().unwrap();
        assert_eq!(required, vec!["name", "age"]);
    }

    #[test]
    fn test_required_fields_none_when_absent() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let model = DynamicGraphModel::from_schema("Flexible", schema);
        assert!(model.required_fields().is_none());
    }

    #[test]
    fn test_required_fields_none_when_not_an_array() {
        let model =
            DynamicGraphModel::from_schema("Odd", serde_json::json!({ "required": "name" }));
        assert!(model.required_fields().is_none());
    }

    #[test]
    fn test_required_fields_skips_non_string_entries() {
        let model = DynamicGraphModel::from_schema(
            "Mixed",
            serde_json::json!({ "required": ["name", 1, null, "age"] }),
        );
        assert_eq!(model.required_fields().unwrap(), vec!["name", "age"]);
    }

    // ─── check_required_fields ───────────────────────────────────────────────

    #[test]
    fn test_check_required_fields_pass() {
        let model = DynamicGraphModel::from_schema("Item", item_schema());

        let instance = serde_json::json!({
            "name": "test",
            "value": 42,
            "extra": true
        });
        assert!(model.check_required_fields(&instance).is_ok());
    }

    #[test]
    fn test_check_required_fields_missing() {
        let model = DynamicGraphModel::from_schema("Item", item_schema());

        let instance = serde_json::json!({ "name": "test" });
        let err = model.check_required_fields(&instance).unwrap_err();
        assert!(
            err.contains("value"),
            "Error should mention missing field: {err}"
        );
    }

    #[test]
    fn test_check_required_fields_lists_every_missing_field() {
        let model = DynamicGraphModel::from_schema("Item", item_schema());

        let err = model
            .check_required_fields(&serde_json::json!({}))
            .unwrap_err();
        assert_eq!(err, "Model 'Item' is missing required fields: name, value");
    }

    #[test]
    fn test_check_required_fields_not_object() {
        let schema = serde_json::json!({
            "type": "object",
            "required": ["name"]
        });
        let model = DynamicGraphModel::from_schema("Item", schema);

        let instance = serde_json::json!("not an object");
        let err = model.check_required_fields(&instance).unwrap_err();
        assert!(
            err.contains("Expected a JSON object"),
            "Error should mention type mismatch: {err}"
        );
        assert_eq!(err, "Expected a JSON object for model 'Item', got string");
    }

    #[test]
    fn test_check_required_fields_no_required() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "name": { "type": "string" } }
        });
        let model = DynamicGraphModel::from_schema("Flexible", schema);

        // Any object should pass when there are no required fields
        let instance = serde_json::json!({});
        assert!(model.check_required_fields(&instance).is_ok());
    }

    // ─── graph_schema_to_graph_model ─────────────────────────────────────────

    #[test]
    fn test_graph_schema_to_graph_model_accepts_canonical_shape() {
        let value = serde_json::json!({
            "entity_types": [{"name": "Person"}],
            "relationship_types": [{"name": "WORKS_AT"}],
        });
        assert!(graph_schema_to_graph_model(&value).is_ok());
    }

    #[test]
    fn test_graph_schema_to_graph_model_rejects_non_object() {
        let value = serde_json::json!([]);
        let err = graph_schema_to_graph_model(&value).unwrap_err();
        assert_eq!(err, GraphModelError::NotAnObject("array"));
    }

    #[test]
    fn test_graph_schema_to_graph_model_missing_entity_types() {
        let value = serde_json::json!({"relationship_types": []});
        let err = graph_schema_to_graph_model(&value).unwrap_err();
        assert_eq!(err, GraphModelError::MissingKey("entity_types"));
    }

    #[test]
    fn test_graph_schema_to_graph_model_reports_entity_types_first() {
        // With both keys absent, `entity_types` is checked (and reported) first.
        let err = graph_schema_to_graph_model(&serde_json::json!({})).unwrap_err();
        assert_eq!(err, GraphModelError::MissingKey("entity_types"));
    }

    #[test]
    fn test_graph_schema_to_graph_model_missing_relationship_types() {
        let value = serde_json::json!({"entity_types": []});
        let err = graph_schema_to_graph_model(&value).unwrap_err();
        assert_eq!(err, GraphModelError::MissingKey("relationship_types"));
    }

    #[test]
    fn test_graph_schema_to_graph_model_entity_types_must_be_array() {
        let value = serde_json::json!({
            "entity_types": "wrong",
            "relationship_types": [],
        });
        let err = graph_schema_to_graph_model(&value).unwrap_err();
        assert_eq!(err, GraphModelError::NotAList("entity_types", "string"));
    }

    #[test]
    fn test_graph_schema_to_graph_model_relationship_types_must_be_array() {
        let value = serde_json::json!({
            "entity_types": [],
            "relationship_types": {"name": "WORKS_AT"},
        });
        let err = graph_schema_to_graph_model(&value).unwrap_err();
        assert_eq!(
            err,
            GraphModelError::NotAList("relationship_types", "object")
        );
    }

    #[test]
    fn test_graph_model_error_display() {
        assert_eq!(
            GraphModelError::NotAnObject("array").to_string(),
            "graph schema must be a JSON object, got array"
        );
        assert_eq!(
            GraphModelError::MissingKey("entity_types").to_string(),
            "graph schema is missing required key `entity_types`"
        );
        assert_eq!(
            GraphModelError::NotAList("entity_types", "string").to_string(),
            "graph schema field `entity_types` must be a list, got string"
        );
    }

    // ─── value_type_name ─────────────────────────────────────────────────────

    #[test]
    fn test_value_type_name_covers_every_variant() {
        assert_eq!(value_type_name(&Value::Null), "null");
        assert_eq!(value_type_name(&serde_json::json!(true)), "boolean");
        assert_eq!(value_type_name(&serde_json::json!(1.5)), "number");
        assert_eq!(value_type_name(&serde_json::json!("s")), "string");
        assert_eq!(value_type_name(&serde_json::json!([])), "array");
        assert_eq!(value_type_name(&serde_json::json!({})), "object");
    }
}