1use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::path::Path;
35
36use crate::error::{PachaError, Result};
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelManifest {
45 pub base_model: String,
47 pub system_prompt: Option<String>,
49 pub parameters: ManifestParameters,
51 pub template: Option<String>,
53 pub adapter: Option<String>,
55 pub license: Option<String>,
57 pub description: Option<String>,
59 pub metadata: HashMap<String, String>,
61}
62
63impl Default for ModelManifest {
64 fn default() -> Self {
65 Self {
66 base_model: String::new(),
67 system_prompt: None,
68 parameters: ManifestParameters::default(),
69 template: None,
70 adapter: None,
71 license: None,
72 description: None,
73 metadata: HashMap::new(),
74 }
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ManifestParameters {
81 pub temperature: Option<f32>,
83 pub top_p: Option<f32>,
85 pub top_k: Option<usize>,
87 pub max_tokens: Option<usize>,
89 pub stop: Vec<String>,
91 pub repeat_penalty: Option<f32>,
93 pub repeat_last_n: Option<usize>,
95 pub context_length: Option<usize>,
97 pub seed: Option<u64>,
99}
100
101impl Default for ManifestParameters {
102 fn default() -> Self {
103 Self {
104 temperature: None,
105 top_p: None,
106 top_k: None,
107 max_tokens: None,
108 stop: Vec::new(),
109 repeat_penalty: None,
110 repeat_last_n: None,
111 context_length: None,
112 seed: None,
113 }
114 }
115}
116
117impl ModelManifest {
118 #[must_use]
120 pub fn new(base_model: impl Into<String>) -> Self {
121 Self { base_model: base_model.into(), ..Default::default() }
122 }
123
124 #[must_use]
126 pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
127 self.system_prompt = Some(prompt.into());
128 self
129 }
130
131 #[must_use]
133 pub fn with_temperature(mut self, temp: f32) -> Self {
134 self.parameters.temperature = Some(temp);
135 self
136 }
137
138 #[must_use]
140 pub fn with_top_p(mut self, top_p: f32) -> Self {
141 self.parameters.top_p = Some(top_p);
142 self
143 }
144
145 #[must_use]
147 pub fn with_top_k(mut self, top_k: usize) -> Self {
148 self.parameters.top_k = Some(top_k);
149 self
150 }
151
152 #[must_use]
154 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
155 self.parameters.max_tokens = Some(max_tokens);
156 self
157 }
158
159 #[must_use]
161 pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
162 self.parameters.stop.push(stop.into());
163 self
164 }
165
166 #[must_use]
168 pub fn with_template(mut self, template: impl Into<String>) -> Self {
169 self.template = Some(template.into());
170 self
171 }
172
173 #[must_use]
175 pub fn with_adapter(mut self, adapter: impl Into<String>) -> Self {
176 self.adapter = Some(adapter.into());
177 self
178 }
179
180 #[must_use]
182 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
183 self.description = Some(desc.into());
184 self
185 }
186
187 #[must_use]
189 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
190 self.metadata.insert(key.into(), value.into());
191 self
192 }
193
194 pub fn parse(content: &str) -> Result<Self> {
196 let mut manifest = Self::default();
197
198 for line in content.lines() {
199 let line = line.trim();
200
201 if line.is_empty() || line.starts_with('#') {
203 continue;
204 }
205
206 let (directive, value) = if let Some(idx) = line.find(char::is_whitespace) {
208 let (d, v) = line.split_at(idx);
209 (d.to_uppercase(), v.trim())
210 } else {
211 (line.to_uppercase(), "")
212 };
213
214 match directive.as_str() {
215 "FROM" => {
216 if value.is_empty() {
217 return Err(PachaError::Validation(
218 "FROM requires a model reference".to_string(),
219 ));
220 }
221 manifest.base_model = value.to_string();
222 }
223 "SYSTEM" => {
224 manifest.system_prompt = Some(value.to_string());
225 }
226 "PARAMETER" => {
227 parse_parameter(&mut manifest.parameters, value)?;
228 }
229 "TEMPLATE" => {
230 let template = value.trim_matches('"').trim_matches('\'');
232 manifest.template = Some(template.to_string());
233 }
234 "ADAPTER" => {
235 manifest.adapter = Some(value.to_string());
236 }
237 "LICENSE" => {
238 manifest.license = Some(value.to_string());
239 }
240 "MESSAGE" => {
241 manifest.metadata.insert("message".to_string(), value.to_string());
243 }
244 _ => {
245 manifest.metadata.insert(directive.to_lowercase(), value.to_string());
247 }
248 }
249 }
250
251 if manifest.base_model.is_empty() {
252 return Err(PachaError::Validation("Modelfile must have FROM directive".to_string()));
253 }
254
255 Ok(manifest)
256 }
257
258 pub fn load(path: &Path) -> Result<Self> {
260 let content = std::fs::read_to_string(path).map_err(|e| {
261 PachaError::Io(std::io::Error::new(
262 e.kind(),
263 format!("Failed to read {}: {}", path.display(), e),
264 ))
265 })?;
266 Self::parse(&content)
267 }
268
269 pub fn save(&self, path: &Path) -> Result<()> {
271 let content = self.to_modelfile();
272 std::fs::write(path, content).map_err(|e| {
273 PachaError::Io(std::io::Error::new(
274 e.kind(),
275 format!("Failed to write {}: {}", path.display(), e),
276 ))
277 })
278 }
279
280 #[must_use]
282 pub fn to_modelfile(&self) -> String {
283 let mut lines = Vec::new();
284
285 lines.push(format!("FROM {}", self.base_model));
287
288 if let Some(ref system) = self.system_prompt {
290 lines.push(format!("SYSTEM {}", system));
291 }
292
293 if let Some(temp) = self.parameters.temperature {
295 lines.push(format!("PARAMETER temperature {}", temp));
296 }
297 if let Some(top_p) = self.parameters.top_p {
298 lines.push(format!("PARAMETER top_p {}", top_p));
299 }
300 if let Some(top_k) = self.parameters.top_k {
301 lines.push(format!("PARAMETER top_k {}", top_k));
302 }
303 if let Some(max_tokens) = self.parameters.max_tokens {
304 lines.push(format!("PARAMETER max_tokens {}", max_tokens));
305 }
306 for stop in &self.parameters.stop {
307 lines.push(format!("PARAMETER stop \"{}\"", stop));
308 }
309 if let Some(repeat_penalty) = self.parameters.repeat_penalty {
310 lines.push(format!("PARAMETER repeat_penalty {}", repeat_penalty));
311 }
312 if let Some(context_length) = self.parameters.context_length {
313 lines.push(format!("PARAMETER context_length {}", context_length));
314 }
315 if let Some(seed) = self.parameters.seed {
316 lines.push(format!("PARAMETER seed {}", seed));
317 }
318
319 if let Some(ref template) = self.template {
321 lines.push(format!("TEMPLATE \"{}\"", template));
322 }
323
324 if let Some(ref adapter) = self.adapter {
326 lines.push(format!("ADAPTER {}", adapter));
327 }
328
329 if let Some(ref license) = self.license {
331 lines.push(format!("LICENSE {}", license));
332 }
333
334 lines.join("\n")
335 }
336
337 pub fn to_json(&self) -> Result<String> {
339 serde_json::to_string_pretty(self)
340 .map_err(|e| PachaError::Validation(format!("Failed to serialize manifest: {}", e)))
341 }
342
343 pub fn from_json(json: &str) -> Result<Self> {
345 serde_json::from_str(json)
346 .map_err(|e| PachaError::Validation(format!("Failed to parse manifest JSON: {}", e)))
347 }
348}
349
350fn parse_parameter(params: &mut ManifestParameters, value: &str) -> Result<()> {
352 let parts: Vec<&str> = value.splitn(2, char::is_whitespace).collect();
353 if parts.len() != 2 {
354 return Err(PachaError::Validation(format!("Invalid PARAMETER syntax: {}", value)));
355 }
356 let (name, val) = (parts[0].to_lowercase(), parts[1].trim());
357 apply_parameter(params, &name, val)
358}
359
360fn apply_parameter(params: &mut ManifestParameters, name: &str, val: &str) -> Result<()> {
361 match name {
362 "temperature" => params.temperature = Some(parse_named(val, "temperature")?),
363 "top_p" => params.top_p = Some(parse_named(val, "top_p")?),
364 "top_k" => params.top_k = Some(parse_named(val, "top_k")?),
365 "max_tokens" | "num_predict" => params.max_tokens = Some(parse_named(val, "max_tokens")?),
366 "stop" => {
367 let stop = val.trim_matches('"').trim_matches('\'');
368 params.stop.push(stop.to_string());
369 }
370 "repeat_penalty" => params.repeat_penalty = Some(parse_named(val, "repeat_penalty")?),
371 "repeat_last_n" => params.repeat_last_n = Some(parse_named(val, "repeat_last_n")?),
372 "context_length" | "num_ctx" => {
373 params.context_length = Some(parse_named(val, "context_length")?);
374 }
375 "seed" => params.seed = Some(parse_named(val, "seed")?),
376 _ => {
377 }
379 }
380 Ok(())
381}
382
383fn parse_named<T: std::str::FromStr>(val: &str, field: &str) -> Result<T> {
384 val.parse().map_err(|_| PachaError::Validation(format!("Invalid {field}: {val}")))
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` and panics if the parser rejects it.
    fn parse_ok(src: &str) -> ModelManifest {
        ModelManifest::parse(src).unwrap()
    }

    // ---- Parsing -----------------------------------------------------------

    #[test]
    fn test_parse_minimal() {
        assert_eq!(parse_ok("FROM llama3").base_model, "llama3");
    }

    #[test]
    fn test_parse_with_system() {
        let manifest = parse_ok(
            r#"
            FROM llama3:8b
            SYSTEM You are a helpful assistant.
        "#,
        );

        assert_eq!(manifest.base_model, "llama3:8b");
        assert_eq!(manifest.system_prompt.as_deref(), Some("You are a helpful assistant."));
    }

    #[test]
    fn test_parse_with_parameters() {
        let params = parse_ok(
            r#"
            FROM mistral
            PARAMETER temperature 0.7
            PARAMETER top_p 0.9
            PARAMETER top_k 40
            PARAMETER max_tokens 256
        "#,
        )
        .parameters;

        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.top_k, Some(40));
        assert_eq!(params.max_tokens, Some(256));
    }

    #[test]
    fn test_parse_with_stop_sequences() {
        let stops = parse_ok(
            r#"
            FROM llama3
            PARAMETER stop "<|endoftext|>"
            PARAMETER stop "User:"
        "#,
        )
        .parameters
        .stop;

        assert_eq!(stops.len(), 2);
        assert!(stops.iter().any(|s| s == "<|endoftext|>"));
        assert!(stops.iter().any(|s| s == "User:"));
    }

    #[test]
    fn test_parse_with_template() {
        let manifest = parse_ok(
            r#"
            FROM llama3
            TEMPLATE "{{ .System }}\nUser: {{ .Prompt }}\nAssistant:"
        "#,
        );

        let template = manifest.template.expect("template should be set");
        assert!(template.contains("System"));
    }

    #[test]
    fn test_parse_with_adapter() {
        let manifest = parse_ok(
            r#"
            FROM llama3:8b
            ADAPTER /path/to/lora.safetensors
        "#,
        );

        assert_eq!(manifest.adapter.as_deref(), Some("/path/to/lora.safetensors"));
    }

    #[test]
    fn test_parse_with_comments() {
        let manifest = parse_ok(
            r#"
            # This is a comment
            FROM llama3
            # Another comment
            SYSTEM Be helpful
        "#,
        );

        assert_eq!(manifest.base_model, "llama3");
        assert!(manifest.system_prompt.is_some());
    }

    #[test]
    fn test_parse_missing_from() {
        assert!(ModelManifest::parse("SYSTEM You are helpful.").is_err());
    }

    #[test]
    fn test_parse_empty_from() {
        assert!(ModelManifest::parse("FROM").is_err());
    }

    // ---- Builder -----------------------------------------------------------

    #[test]
    fn test_builder() {
        let manifest = ModelManifest::new("llama3:8b")
            .with_system("You are a coding assistant.")
            .with_temperature(0.8)
            .with_top_p(0.95)
            .with_max_tokens(1024)
            .with_stop("<|end|>")
            .with_description("My custom model");

        assert_eq!(manifest.base_model, "llama3:8b");
        assert!(manifest.system_prompt.is_some());

        let params = &manifest.parameters;
        assert_eq!(params.temperature, Some(0.8));
        assert_eq!(params.top_p, Some(0.95));
        assert_eq!(params.max_tokens, Some(1024));
        assert_eq!(params.stop.len(), 1);

        assert!(manifest.description.is_some());
    }

    #[test]
    fn test_builder_with_metadata() {
        let manifest = ModelManifest::new("llama3")
            .with_metadata("author", "test")
            .with_metadata("version", "1.0");

        assert_eq!(manifest.metadata.get("author").map(String::as_str), Some("test"));
        assert_eq!(manifest.metadata.get("version").map(String::as_str), Some("1.0"));
    }

    // ---- Serialization -----------------------------------------------------

    #[test]
    fn test_to_modelfile() {
        let modelfile = ModelManifest::new("llama3:8b")
            .with_system("Be helpful")
            .with_temperature(0.7)
            .to_modelfile();

        for expected in ["FROM llama3:8b", "SYSTEM Be helpful", "PARAMETER temperature 0.7"] {
            assert!(modelfile.contains(expected));
        }
    }

    #[test]
    fn test_roundtrip() {
        let original = ModelManifest::new("mixtral:8x7b")
            .with_system("You are an expert.")
            .with_temperature(0.9)
            .with_top_k(50)
            .with_max_tokens(2048);

        let parsed = parse_ok(&original.to_modelfile());

        assert_eq!(parsed.base_model, original.base_model);
        assert_eq!(parsed.system_prompt, original.system_prompt);
        assert_eq!(parsed.parameters.temperature, original.parameters.temperature);
        assert_eq!(parsed.parameters.top_k, original.parameters.top_k);
        assert_eq!(parsed.parameters.max_tokens, original.parameters.max_tokens);
    }

    #[test]
    fn test_json_roundtrip() {
        let original = ModelManifest::new("llama3").with_system("Test").with_temperature(0.5);

        let json = original.to_json().unwrap();
        let parsed = ModelManifest::from_json(&json).unwrap();

        assert_eq!(parsed.base_model, original.base_model);
        assert_eq!(parsed.system_prompt, original.system_prompt);
    }

    // ---- Parameter aliases and values --------------------------------------

    #[test]
    fn test_parse_context_length_alias() {
        let manifest = parse_ok(
            r#"
            FROM llama3
            PARAMETER num_ctx 4096
        "#,
        );

        assert_eq!(manifest.parameters.context_length, Some(4096));
    }

    #[test]
    fn test_parse_max_tokens_alias() {
        let manifest = parse_ok(
            r#"
            FROM llama3
            PARAMETER num_predict 512
        "#,
        );

        assert_eq!(manifest.parameters.max_tokens, Some(512));
    }

    #[test]
    fn test_parse_repeat_penalty() {
        let params = parse_ok(
            r#"
            FROM llama3
            PARAMETER repeat_penalty 1.1
            PARAMETER repeat_last_n 64
        "#,
        )
        .parameters;

        assert_eq!(params.repeat_penalty, Some(1.1));
        assert_eq!(params.repeat_last_n, Some(64));
    }

    #[test]
    fn test_parse_seed() {
        let manifest = parse_ok(
            r#"
            FROM llama3
            PARAMETER seed 42
        "#,
        );

        assert_eq!(manifest.parameters.seed, Some(42));
    }

    #[test]
    fn test_invalid_parameter_value() {
        let outcome = ModelManifest::parse(
            r#"
            FROM llama3
            PARAMETER temperature not_a_number
        "#,
        );

        assert!(outcome.is_err());
    }

    // ---- Defaults ----------------------------------------------------------

    #[test]
    fn test_default_parameters() {
        let params = ManifestParameters::default();

        assert!(params.temperature.is_none());
        assert!(params.top_p.is_none());
        assert!(params.stop.is_empty());
    }

    #[test]
    fn test_default_manifest() {
        let manifest = ModelManifest::default();

        assert!(manifest.base_model.is_empty());
        assert!(manifest.system_prompt.is_none());
    }
}