Skip to main content

tl_ai/
model.rs

// ThinkingLanguage — Model type
// Represents trained ML models (linfa, ONNX, LLM endpoints).
3
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt;
7use std::fs;
8use std::path::{Path, PathBuf};
9
10/// A trained model in ThinkingLanguage.
11#[derive(Clone)]
12pub enum TlModel {
13    /// An ONNX model loaded from disk.
14    Onnx { path: PathBuf, metadata: ModelMeta },
15    /// A linfa-trained model (serialized).
16    Linfa {
17        kind: LinfaKind,
18        data: Vec<u8>,
19        metadata: ModelMeta,
20    },
21    /// An LLM endpoint (Claude, OpenAI, etc.)
22    LlmEndpoint {
23        provider: String,
24        model_name: String,
25        api_key: Option<String>,
26    },
27}
28
29impl fmt::Debug for TlModel {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            TlModel::Onnx { metadata, .. } => write!(f, "<model:onnx {}>", metadata.name),
33            TlModel::Linfa { kind, metadata, .. } => {
34                write!(f, "<model:{kind:?} {}>", metadata.name)
35            }
36            TlModel::LlmEndpoint {
37                provider,
38                model_name,
39                ..
40            } => write!(f, "<model:llm {provider}/{model_name}>"),
41        }
42    }
43}
44
45impl fmt::Display for TlModel {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            TlModel::Onnx { metadata, .. } => write!(f, "<model {}>", metadata.name),
49            TlModel::Linfa { metadata, .. } => write!(f, "<model {}>", metadata.name),
50            TlModel::LlmEndpoint {
51                provider,
52                model_name,
53                ..
54            } => write!(f, "<model {provider}/{model_name}>"),
55        }
56    }
57}
58
59/// What kind of linfa model.
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum LinfaKind {
62    LinearRegression,
63    LogisticRegression,
64    DecisionTree,
65    RandomForest,
66    KMeans,
67    Knn,
68    NaiveBayes,
69    Dbscan,
70    Ridge,
71    GradientBoosting,
72}
73
74/// Model metadata.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ModelMeta {
77    pub name: String,
78    pub version: String,
79    pub created_at: String,
80    pub features: Vec<String>,
81    pub target: String,
82    pub metrics: HashMap<String, f64>,
83}
84
85impl Default for ModelMeta {
86    fn default() -> Self {
87        ModelMeta {
88            name: String::new(),
89            version: "0.1.0".to_string(),
90            created_at: String::new(),
91            features: Vec::new(),
92            target: String::new(),
93            metrics: HashMap::new(),
94        }
95    }
96}
97
98impl TlModel {
99    /// Save a model to a .tlmodel directory.
100    pub fn save(&self, path: &Path) -> Result<(), String> {
101        fs::create_dir_all(path).map_err(|e| format!("Failed to create dir: {e}"))?;
102
103        match self {
104            TlModel::Linfa {
105                kind,
106                data,
107                metadata,
108            } => {
109                let meta = serde_json::json!({
110                    "type": "linfa",
111                    "kind": kind,
112                    "metadata": metadata,
113                });
114                fs::write(
115                    path.join("metadata.json"),
116                    serde_json::to_string_pretty(&meta).unwrap(),
117                )
118                .map_err(|e| format!("Failed to write metadata: {e}"))?;
119                fs::write(path.join("model.bin"), data)
120                    .map_err(|e| format!("Failed to write model: {e}"))?;
121            }
122            TlModel::Onnx {
123                path: onnx_path,
124                metadata,
125            } => {
126                let meta = serde_json::json!({
127                    "type": "onnx",
128                    "onnx_path": onnx_path.display().to_string(),
129                    "metadata": metadata,
130                });
131                fs::write(
132                    path.join("metadata.json"),
133                    serde_json::to_string_pretty(&meta).unwrap(),
134                )
135                .map_err(|e| format!("Failed to write metadata: {e}"))?;
136                // Copy the ONNX file if it exists
137                if onnx_path.exists() {
138                    fs::copy(onnx_path, path.join("model.onnx"))
139                        .map_err(|e| format!("Failed to copy ONNX model: {e}"))?;
140                }
141            }
142            TlModel::LlmEndpoint {
143                provider,
144                model_name,
145                ..
146            } => {
147                let meta = serde_json::json!({
148                    "type": "llm",
149                    "provider": provider,
150                    "model_name": model_name,
151                });
152                fs::write(
153                    path.join("metadata.json"),
154                    serde_json::to_string_pretty(&meta).unwrap(),
155                )
156                .map_err(|e| format!("Failed to write metadata: {e}"))?;
157            }
158        }
159        Ok(())
160    }
161
162    /// Load a model from a .tlmodel directory.
163    pub fn load(path: &Path) -> Result<Self, String> {
164        let meta_path = path.join("metadata.json");
165        let meta_str =
166            fs::read_to_string(&meta_path).map_err(|e| format!("Failed to read metadata: {e}"))?;
167        let meta: serde_json::Value =
168            serde_json::from_str(&meta_str).map_err(|e| format!("Invalid metadata: {e}"))?;
169
170        let model_type = meta["type"].as_str().ok_or("Missing 'type' in metadata")?;
171
172        match model_type {
173            "linfa" => {
174                let kind: LinfaKind = serde_json::from_value(meta["kind"].clone())
175                    .map_err(|e| format!("Invalid linfa kind: {e}"))?;
176                let metadata: ModelMeta = serde_json::from_value(meta["metadata"].clone())
177                    .map_err(|e| format!("Invalid metadata: {e}"))?;
178                let data = fs::read(path.join("model.bin"))
179                    .map_err(|e| format!("Failed to read model: {e}"))?;
180                Ok(TlModel::Linfa {
181                    kind,
182                    data,
183                    metadata,
184                })
185            }
186            "onnx" => {
187                let onnx_path = path.join("model.onnx");
188                let metadata: ModelMeta = serde_json::from_value(meta["metadata"].clone())
189                    .map_err(|e| format!("Invalid metadata: {e}"))?;
190                Ok(TlModel::Onnx {
191                    path: onnx_path,
192                    metadata,
193                })
194            }
195            "llm" => {
196                let provider = meta["provider"].as_str().unwrap_or("unknown").to_string();
197                let model_name = meta["model_name"].as_str().unwrap_or("unknown").to_string();
198                Ok(TlModel::LlmEndpoint {
199                    provider,
200                    model_name,
201                    api_key: None,
202                })
203            }
204            _ => Err(format!("Unknown model type: {model_type}")),
205        }
206    }
207
208    /// Get model metadata (if available).
209    pub fn metadata(&self) -> Option<&ModelMeta> {
210        match self {
211            TlModel::Onnx { metadata, .. } => Some(metadata),
212            TlModel::Linfa { metadata, .. } => Some(metadata),
213            TlModel::LlmEndpoint { .. } => None,
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// A linfa model written with `save` comes back from `load` with the same
    /// kind, payload bytes, and metadata.
    #[test]
    fn test_save_load_linfa_model() {
        let tmp = tempfile::tempdir().unwrap();
        let bundle = tmp.path().join("test.tlmodel");

        let mut metrics = HashMap::new();
        metrics.insert("r2".to_string(), 0.95);

        let original = TlModel::Linfa {
            kind: LinfaKind::LinearRegression,
            data: vec![1, 2, 3, 4],
            metadata: ModelMeta {
                name: "test_model".to_string(),
                version: "0.1.0".to_string(),
                created_at: "2024-01-01".to_string(),
                features: vec!["x1".to_string(), "x2".to_string()],
                target: "y".to_string(),
                metrics,
            },
        };

        original.save(&bundle).unwrap();

        match TlModel::load(&bundle).unwrap() {
            TlModel::Linfa {
                kind,
                data,
                metadata,
            } => {
                assert_eq!(kind, LinfaKind::LinearRegression);
                assert_eq!(data, vec![1, 2, 3, 4]);
                assert_eq!(metadata.name, "test_model");
                assert_eq!(metadata.features.len(), 2);
                // Compare with a tolerance rather than exact float equality.
                assert!((metadata.metrics["r2"] - 0.95).abs() < 1e-10);
            }
            _ => panic!("Expected Linfa model"),
        }
    }

    /// `Display` prints only the metadata name for a linfa model.
    #[test]
    fn test_model_display() {
        let metadata = ModelMeta {
            name: "my_model".to_string(),
            ..Default::default()
        };
        let model = TlModel::Linfa {
            kind: LinfaKind::LinearRegression,
            data: vec![],
            metadata,
        };
        assert_eq!(model.to_string(), "<model my_model>");
    }
}