Skip to main content

cvkg_render_gpu/
ai.rs

1use crate::material::MaterialValidationConfig;
2use crate::material::{MaterialGraph, MaterialSocket};
3
/// Errors reported by AI-driven generators and by the spec-to-graph builders
/// in this module.
#[derive(Debug)]
pub enum GeneratorError {
    /// A channel used by a generator was closed.
    /// NOTE(review): not produced in this file — confirm intended usage.
    ChannelClosed,
    /// A spec node declared a `kind` the builder does not support.
    UnknownNodeType(String),
    /// An edge referenced a node id that does not resolve to a graph node.
    UnknownNode(String),
    /// The spec or the resulting graph failed structural validation; the
    /// payload describes the reason.
    ValidationFailed(String),
    /// Any other failure, described by the payload.
    Other(String),
}

impl std::fmt::Display for GeneratorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GeneratorError::ChannelClosed => f.write_str("generator channel closed"),
            GeneratorError::UnknownNodeType(kind) => write!(f, "unknown node type `{}`", kind),
            GeneratorError::UnknownNode(id) => write!(f, "unknown node `{}`", id),
            GeneratorError::ValidationFailed(msg) => write!(f, "validation failed: {}", msg),
            GeneratorError::Other(msg) => f.write_str(msg),
        }
    }
}

/// Lets `GeneratorError` flow through `?` into `Box<dyn std::error::Error>`
/// and other error-erasing wrappers.
impl std::error::Error for GeneratorError {}
12
13/// A stable API for AI systems (LLMs, diffusion models) to generate
14/// and inject materials or subgraphs into the Kvasir renderer.
15pub trait KvasirGenerator: Send + Sync {
16    fn generate(&self) -> Result<(), GeneratorError>;
17}
18
19/// JSON representation of a Material Graph, allowing AI to author materials safely.
20#[derive(Debug, serde::Deserialize)]
21pub struct MaterialGraphSpec {
22    pub nodes: Vec<NodeSpec>,
23    pub edges: Vec<EdgeSpec>,
24}
25
26#[derive(Debug, serde::Deserialize)]
27pub struct NodeSpec {
28    pub id: String,
29    pub kind: String,
30    pub params: std::collections::HashMap<String, serde_json::Value>,
31}
32
33#[derive(Debug, serde::Deserialize)]
34pub struct EdgeSpec {
35    pub from_node: String,
36    pub from_socket: String,
37    pub to_node: String,
38    pub to_socket: String,
39}
40
41impl MaterialGraphSpec {
42    pub fn build_graph(&self) -> Result<MaterialGraph, GeneratorError> {
43        let mut mat = MaterialGraph::new();
44        let mut node_map = std::collections::HashMap::new();
45
46        for node_spec in &self.nodes {
47            let key = match node_spec.kind.as_str() {
48                "SolidColor" => {
49                    let r = node_spec
50                        .params
51                        .get("r")
52                        .and_then(|v| v.as_f64())
53                        .unwrap_or(1.0) as f32;
54                    let g = node_spec
55                        .params
56                        .get("g")
57                        .and_then(|v| v.as_f64())
58                        .unwrap_or(1.0) as f32;
59                    let b = node_spec
60                        .params
61                        .get("b")
62                        .and_then(|v| v.as_f64())
63                        .unwrap_or(1.0) as f32;
64                    let a = node_spec
65                        .params
66                        .get("a")
67                        .and_then(|v| v.as_f64())
68                        .unwrap_or(1.0) as f32;
69                    mat.add_node(crate::material::MaterialOp::ConstantColor { r, g, b, a })
70                }
71                "Output" => {
72                    // special node just to mark output
73                    // since Output is an operation, wait, MaterialGraph sets output by id
74                    u32::MAX
75                }
76                kind => return Err(GeneratorError::UnknownNodeType(kind.to_string())),
77            };
78            if key != u32::MAX {
79                node_map.insert(node_spec.id.clone(), key);
80            } else {
81                // If this is the Output spec node, it means whatever it connects to is the output.
82                // Wait, it's easier to just find the edge going into Output.
83            }
84        }
85
86        // Output node discovery
87        let output_edge = self.edges.iter().find(|e| e.to_node == "Output");
88        if let Some(edge) = output_edge {
89            let from = node_map
90                .get(&edge.from_node)
91                .ok_or_else(|| GeneratorError::UnknownNode(edge.from_node.clone()))?;
92            mat.set_output(*from);
93        } else {
94            return Err(GeneratorError::ValidationFailed(
95                "No edge to Output node".into(),
96            ));
97        }
98
99        for edge in &self.edges {
100            if edge.to_node == "Output" {
101                continue;
102            }
103            let from = node_map
104                .get(&edge.from_node)
105                .ok_or_else(|| GeneratorError::UnknownNode(edge.from_node.clone()))?;
106            let to = node_map
107                .get(&edge.to_node)
108                .ok_or_else(|| GeneratorError::UnknownNode(edge.to_node.clone()))?;
109
110            mat.connect(*from, MaterialSocket::Color, *to, MaterialSocket::Color);
111        }
112
113        mat.validate_with_config(&MaterialValidationConfig {
114            max_nodes: 32,
115            max_edges: 64,
116        })
117        .map_err(|e| GeneratorError::ValidationFailed(e.to_string()))?;
118
119        Ok(mat)
120    }
121}