1use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt;
7use std::fs;
8use std::path::{Path, PathBuf};
9
10#[derive(Clone)]
12pub enum TlModel {
13 Onnx { path: PathBuf, metadata: ModelMeta },
15 Linfa {
17 kind: LinfaKind,
18 data: Vec<u8>,
19 metadata: ModelMeta,
20 },
21 LlmEndpoint {
23 provider: String,
24 model_name: String,
25 api_key: Option<String>,
26 },
27}
28
29impl fmt::Debug for TlModel {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 TlModel::Onnx { metadata, .. } => write!(f, "<model:onnx {}>", metadata.name),
33 TlModel::Linfa { kind, metadata, .. } => {
34 write!(f, "<model:{kind:?} {}>", metadata.name)
35 }
36 TlModel::LlmEndpoint {
37 provider,
38 model_name,
39 ..
40 } => write!(f, "<model:llm {provider}/{model_name}>"),
41 }
42 }
43}
44
45impl fmt::Display for TlModel {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 TlModel::Onnx { metadata, .. } => write!(f, "<model {}>", metadata.name),
49 TlModel::Linfa { metadata, .. } => write!(f, "<model {}>", metadata.name),
50 TlModel::LlmEndpoint {
51 provider,
52 model_name,
53 ..
54 } => write!(f, "<model {provider}/{model_name}>"),
55 }
56 }
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum LinfaKind {
62 LinearRegression,
63 LogisticRegression,
64 DecisionTree,
65 RandomForest,
66 KMeans,
67 Knn,
68 NaiveBayes,
69 Dbscan,
70 Ridge,
71 GradientBoosting,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ModelMeta {
77 pub name: String,
78 pub version: String,
79 pub created_at: String,
80 pub features: Vec<String>,
81 pub target: String,
82 pub metrics: HashMap<String, f64>,
83}
84
85impl Default for ModelMeta {
86 fn default() -> Self {
87 ModelMeta {
88 name: String::new(),
89 version: "0.1.0".to_string(),
90 created_at: String::new(),
91 features: Vec::new(),
92 target: String::new(),
93 metrics: HashMap::new(),
94 }
95 }
96}
97
98impl TlModel {
99 pub fn save(&self, path: &Path) -> Result<(), String> {
101 fs::create_dir_all(path).map_err(|e| format!("Failed to create dir: {e}"))?;
102
103 match self {
104 TlModel::Linfa {
105 kind,
106 data,
107 metadata,
108 } => {
109 let meta = serde_json::json!({
110 "type": "linfa",
111 "kind": kind,
112 "metadata": metadata,
113 });
114 fs::write(
115 path.join("metadata.json"),
116 serde_json::to_string_pretty(&meta).unwrap(),
117 )
118 .map_err(|e| format!("Failed to write metadata: {e}"))?;
119 fs::write(path.join("model.bin"), data)
120 .map_err(|e| format!("Failed to write model: {e}"))?;
121 }
122 TlModel::Onnx {
123 path: onnx_path,
124 metadata,
125 } => {
126 let meta = serde_json::json!({
127 "type": "onnx",
128 "onnx_path": onnx_path.display().to_string(),
129 "metadata": metadata,
130 });
131 fs::write(
132 path.join("metadata.json"),
133 serde_json::to_string_pretty(&meta).unwrap(),
134 )
135 .map_err(|e| format!("Failed to write metadata: {e}"))?;
136 if onnx_path.exists() {
138 fs::copy(onnx_path, path.join("model.onnx"))
139 .map_err(|e| format!("Failed to copy ONNX model: {e}"))?;
140 }
141 }
142 TlModel::LlmEndpoint {
143 provider,
144 model_name,
145 ..
146 } => {
147 let meta = serde_json::json!({
148 "type": "llm",
149 "provider": provider,
150 "model_name": model_name,
151 });
152 fs::write(
153 path.join("metadata.json"),
154 serde_json::to_string_pretty(&meta).unwrap(),
155 )
156 .map_err(|e| format!("Failed to write metadata: {e}"))?;
157 }
158 }
159 Ok(())
160 }
161
162 pub fn load(path: &Path) -> Result<Self, String> {
164 let meta_path = path.join("metadata.json");
165 let meta_str =
166 fs::read_to_string(&meta_path).map_err(|e| format!("Failed to read metadata: {e}"))?;
167 let meta: serde_json::Value =
168 serde_json::from_str(&meta_str).map_err(|e| format!("Invalid metadata: {e}"))?;
169
170 let model_type = meta["type"].as_str().ok_or("Missing 'type' in metadata")?;
171
172 match model_type {
173 "linfa" => {
174 let kind: LinfaKind = serde_json::from_value(meta["kind"].clone())
175 .map_err(|e| format!("Invalid linfa kind: {e}"))?;
176 let metadata: ModelMeta = serde_json::from_value(meta["metadata"].clone())
177 .map_err(|e| format!("Invalid metadata: {e}"))?;
178 let data = fs::read(path.join("model.bin"))
179 .map_err(|e| format!("Failed to read model: {e}"))?;
180 Ok(TlModel::Linfa {
181 kind,
182 data,
183 metadata,
184 })
185 }
186 "onnx" => {
187 let onnx_path = path.join("model.onnx");
188 let metadata: ModelMeta = serde_json::from_value(meta["metadata"].clone())
189 .map_err(|e| format!("Invalid metadata: {e}"))?;
190 Ok(TlModel::Onnx {
191 path: onnx_path,
192 metadata,
193 })
194 }
195 "llm" => {
196 let provider = meta["provider"].as_str().unwrap_or("unknown").to_string();
197 let model_name = meta["model_name"].as_str().unwrap_or("unknown").to_string();
198 Ok(TlModel::LlmEndpoint {
199 provider,
200 model_name,
201 api_key: None,
202 })
203 }
204 _ => Err(format!("Unknown model type: {model_type}")),
205 }
206 }
207
208 pub fn metadata(&self) -> Option<&ModelMeta> {
210 match self {
211 TlModel::Onnx { metadata, .. } => Some(metadata),
212 TlModel::Linfa { metadata, .. } => Some(metadata),
213 TlModel::LlmEndpoint { .. } => None,
214 }
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Linfa fixture: a 4-byte payload plus metadata with two features,
    /// target `y` and a single `r2` metric.
    fn sample_linfa_model() -> TlModel {
        TlModel::Linfa {
            kind: LinfaKind::LinearRegression,
            data: vec![1, 2, 3, 4],
            metadata: ModelMeta {
                name: "test_model".to_string(),
                version: "0.1.0".to_string(),
                created_at: "2024-01-01".to_string(),
                features: vec!["x1".to_string(), "x2".to_string()],
                target: "y".to_string(),
                metrics: HashMap::from([("r2".to_string(), 0.95)]),
            },
        }
    }

    /// Saving then loading a linfa model must reproduce kind, bytes and metadata.
    #[test]
    fn test_save_load_linfa_model() {
        let tmp = tempfile::tempdir().unwrap();
        let target_dir = tmp.path().join("test.tlmodel");

        sample_linfa_model().save(&target_dir).unwrap();

        match TlModel::load(&target_dir).unwrap() {
            TlModel::Linfa {
                kind,
                data,
                metadata,
            } => {
                assert_eq!(kind, LinfaKind::LinearRegression);
                assert_eq!(data, vec![1, 2, 3, 4]);
                assert_eq!(metadata.name, "test_model");
                assert_eq!(metadata.features.len(), 2);
                assert!((metadata.metrics["r2"] - 0.95).abs() < 1e-10);
            }
            _ => panic!("Expected Linfa model"),
        }
    }

    /// `Display` shows only the model name for linfa models.
    #[test]
    fn test_model_display() {
        let named = TlModel::Linfa {
            kind: LinfaKind::LinearRegression,
            data: Vec::new(),
            metadata: ModelMeta {
                name: "my_model".to_string(),
                ..Default::default()
            },
        };
        assert_eq!(named.to_string(), "<model my_model>");
    }
}
275}