1use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::path::Path;
35
36use crate::error::{PachaError, Result};
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelManifest {
45 pub base_model: String,
47 pub system_prompt: Option<String>,
49 pub parameters: ManifestParameters,
51 pub template: Option<String>,
53 pub adapter: Option<String>,
55 pub license: Option<String>,
57 pub description: Option<String>,
59 pub metadata: HashMap<String, String>,
61}
62
63impl Default for ModelManifest {
64 fn default() -> Self {
65 Self {
66 base_model: String::new(),
67 system_prompt: None,
68 parameters: ManifestParameters::default(),
69 template: None,
70 adapter: None,
71 license: None,
72 description: None,
73 metadata: HashMap::new(),
74 }
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ManifestParameters {
81 pub temperature: Option<f32>,
83 pub top_p: Option<f32>,
85 pub top_k: Option<usize>,
87 pub max_tokens: Option<usize>,
89 pub stop: Vec<String>,
91 pub repeat_penalty: Option<f32>,
93 pub repeat_last_n: Option<usize>,
95 pub context_length: Option<usize>,
97 pub seed: Option<u64>,
99}
100
101impl Default for ManifestParameters {
102 fn default() -> Self {
103 Self {
104 temperature: None,
105 top_p: None,
106 top_k: None,
107 max_tokens: None,
108 stop: Vec::new(),
109 repeat_penalty: None,
110 repeat_last_n: None,
111 context_length: None,
112 seed: None,
113 }
114 }
115}
116
117impl ModelManifest {
118 #[must_use]
120 pub fn new(base_model: impl Into<String>) -> Self {
121 Self { base_model: base_model.into(), ..Default::default() }
122 }
123
124 #[must_use]
126 pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
127 self.system_prompt = Some(prompt.into());
128 self
129 }
130
131 #[must_use]
133 pub fn with_temperature(mut self, temp: f32) -> Self {
134 self.parameters.temperature = Some(temp);
135 self
136 }
137
138 #[must_use]
140 pub fn with_top_p(mut self, top_p: f32) -> Self {
141 self.parameters.top_p = Some(top_p);
142 self
143 }
144
145 #[must_use]
147 pub fn with_top_k(mut self, top_k: usize) -> Self {
148 self.parameters.top_k = Some(top_k);
149 self
150 }
151
152 #[must_use]
154 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
155 self.parameters.max_tokens = Some(max_tokens);
156 self
157 }
158
159 #[must_use]
161 pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
162 self.parameters.stop.push(stop.into());
163 self
164 }
165
166 #[must_use]
168 pub fn with_template(mut self, template: impl Into<String>) -> Self {
169 self.template = Some(template.into());
170 self
171 }
172
173 #[must_use]
175 pub fn with_adapter(mut self, adapter: impl Into<String>) -> Self {
176 self.adapter = Some(adapter.into());
177 self
178 }
179
180 #[must_use]
182 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
183 self.description = Some(desc.into());
184 self
185 }
186
187 #[must_use]
189 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
190 self.metadata.insert(key.into(), value.into());
191 self
192 }
193
194 pub fn parse(content: &str) -> Result<Self> {
196 let mut manifest = Self::default();
197
198 for line in content.lines() {
199 let line = line.trim();
200
201 if line.is_empty() || line.starts_with('#') {
203 continue;
204 }
205
206 let (directive, value) = if let Some(idx) = line.find(char::is_whitespace) {
208 let (d, v) = line.split_at(idx);
209 (d.to_uppercase(), v.trim())
210 } else {
211 (line.to_uppercase(), "")
212 };
213
214 match directive.as_str() {
215 "FROM" => {
216 if value.is_empty() {
217 return Err(PachaError::Validation(
218 "FROM requires a model reference".to_string(),
219 ));
220 }
221 manifest.base_model = value.to_string();
222 }
223 "SYSTEM" => {
224 manifest.system_prompt = Some(value.to_string());
225 }
226 "PARAMETER" => {
227 parse_parameter(&mut manifest.parameters, value)?;
228 }
229 "TEMPLATE" => {
230 let template = value.trim_matches('"').trim_matches('\'');
232 manifest.template = Some(template.to_string());
233 }
234 "ADAPTER" => {
235 manifest.adapter = Some(value.to_string());
236 }
237 "LICENSE" => {
238 manifest.license = Some(value.to_string());
239 }
240 "MESSAGE" => {
241 manifest.metadata.insert("message".to_string(), value.to_string());
243 }
244 _ => {
245 manifest.metadata.insert(directive.to_lowercase(), value.to_string());
247 }
248 }
249 }
250
251 if manifest.base_model.is_empty() {
252 return Err(PachaError::Validation("Modelfile must have FROM directive".to_string()));
253 }
254
255 Ok(manifest)
256 }
257
258 pub fn load(path: &Path) -> Result<Self> {
260 let content = std::fs::read_to_string(path).map_err(|e| {
261 PachaError::Io(std::io::Error::new(
262 e.kind(),
263 format!("Failed to read {}: {}", path.display(), e),
264 ))
265 })?;
266 Self::parse(&content)
267 }
268
269 pub fn save(&self, path: &Path) -> Result<()> {
271 let content = self.to_modelfile();
272 std::fs::write(path, content).map_err(|e| {
273 PachaError::Io(std::io::Error::new(
274 e.kind(),
275 format!("Failed to write {}: {}", path.display(), e),
276 ))
277 })
278 }
279
280 #[must_use]
282 pub fn to_modelfile(&self) -> String {
283 let mut lines = Vec::new();
284
285 lines.push(format!("FROM {}", self.base_model));
287
288 if let Some(ref system) = self.system_prompt {
290 lines.push(format!("SYSTEM {}", system));
291 }
292
293 if let Some(temp) = self.parameters.temperature {
295 lines.push(format!("PARAMETER temperature {}", temp));
296 }
297 if let Some(top_p) = self.parameters.top_p {
298 lines.push(format!("PARAMETER top_p {}", top_p));
299 }
300 if let Some(top_k) = self.parameters.top_k {
301 lines.push(format!("PARAMETER top_k {}", top_k));
302 }
303 if let Some(max_tokens) = self.parameters.max_tokens {
304 lines.push(format!("PARAMETER max_tokens {}", max_tokens));
305 }
306 for stop in &self.parameters.stop {
307 lines.push(format!("PARAMETER stop \"{}\"", stop));
308 }
309 if let Some(repeat_penalty) = self.parameters.repeat_penalty {
310 lines.push(format!("PARAMETER repeat_penalty {}", repeat_penalty));
311 }
312 if let Some(context_length) = self.parameters.context_length {
313 lines.push(format!("PARAMETER context_length {}", context_length));
314 }
315 if let Some(seed) = self.parameters.seed {
316 lines.push(format!("PARAMETER seed {}", seed));
317 }
318
319 if let Some(ref template) = self.template {
321 lines.push(format!("TEMPLATE \"{}\"", template));
322 }
323
324 if let Some(ref adapter) = self.adapter {
326 lines.push(format!("ADAPTER {}", adapter));
327 }
328
329 if let Some(ref license) = self.license {
331 lines.push(format!("LICENSE {}", license));
332 }
333
334 lines.join("\n")
335 }
336
337 pub fn to_json(&self) -> Result<String> {
339 serde_json::to_string_pretty(self)
340 .map_err(|e| PachaError::Validation(format!("Failed to serialize manifest: {}", e)))
341 }
342
343 pub fn from_json(json: &str) -> Result<Self> {
345 serde_json::from_str(json)
346 .map_err(|e| PachaError::Validation(format!("Failed to parse manifest JSON: {}", e)))
347 }
348}
349
350fn parse_parameter(params: &mut ManifestParameters, value: &str) -> Result<()> {
352 let parts: Vec<&str> = value.splitn(2, char::is_whitespace).collect();
353 if parts.len() != 2 {
354 return Err(PachaError::Validation(format!("Invalid PARAMETER syntax: {}", value)));
355 }
356
357 let (name, val) = (parts[0].to_lowercase(), parts[1].trim());
358
359 match name.as_str() {
360 "temperature" => {
361 params.temperature =
362 Some(val.parse().map_err(|_| {
363 PachaError::Validation(format!("Invalid temperature: {}", val))
364 })?);
365 }
366 "top_p" => {
367 params.top_p = Some(
368 val.parse()
369 .map_err(|_| PachaError::Validation(format!("Invalid top_p: {}", val)))?,
370 );
371 }
372 "top_k" => {
373 params.top_k = Some(
374 val.parse()
375 .map_err(|_| PachaError::Validation(format!("Invalid top_k: {}", val)))?,
376 );
377 }
378 "max_tokens" | "num_predict" => {
379 params.max_tokens = Some(
380 val.parse()
381 .map_err(|_| PachaError::Validation(format!("Invalid max_tokens: {}", val)))?,
382 );
383 }
384 "stop" => {
385 let stop = val.trim_matches('"').trim_matches('\'');
386 params.stop.push(stop.to_string());
387 }
388 "repeat_penalty" => {
389 params.repeat_penalty =
390 Some(val.parse().map_err(|_| {
391 PachaError::Validation(format!("Invalid repeat_penalty: {}", val))
392 })?);
393 }
394 "repeat_last_n" => {
395 params.repeat_last_n =
396 Some(val.parse().map_err(|_| {
397 PachaError::Validation(format!("Invalid repeat_last_n: {}", val))
398 })?);
399 }
400 "context_length" | "num_ctx" => {
401 params.context_length =
402 Some(val.parse().map_err(|_| {
403 PachaError::Validation(format!("Invalid context_length: {}", val))
404 })?);
405 }
406 "seed" => {
407 params.seed = Some(
408 val.parse()
409 .map_err(|_| PachaError::Validation(format!("Invalid seed: {}", val)))?,
410 );
411 }
412 _ => {
413 }
415 }
416
417 Ok(())
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` and panics if the Modelfile is rejected.
    fn parse_ok(text: &str) -> ModelManifest {
        ModelManifest::parse(text).unwrap()
    }

    // Parsing: individual directives.

    #[test]
    fn test_parse_minimal() {
        let parsed = parse_ok("FROM llama3");
        assert_eq!(parsed.base_model, "llama3");
    }

    #[test]
    fn test_parse_with_system() {
        let parsed = parse_ok(
            r#"
            FROM llama3:8b
            SYSTEM You are a helpful assistant.
        "#,
        );

        assert_eq!(parsed.base_model, "llama3:8b");
        assert_eq!(parsed.system_prompt.as_deref(), Some("You are a helpful assistant."));
    }

    #[test]
    fn test_parse_with_parameters() {
        let parsed = parse_ok(
            r#"
            FROM mistral
            PARAMETER temperature 0.7
            PARAMETER top_p 0.9
            PARAMETER top_k 40
            PARAMETER max_tokens 256
        "#,
        );
        let params = &parsed.parameters;

        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.top_k, Some(40));
        assert_eq!(params.max_tokens, Some(256));
    }

    #[test]
    fn test_parse_with_stop_sequences() {
        let parsed = parse_ok(
            r#"
            FROM llama3
            PARAMETER stop "<|endoftext|>"
            PARAMETER stop "User:"
        "#,
        );
        let stops = &parsed.parameters.stop;

        assert_eq!(stops.len(), 2);
        assert!(stops.iter().any(|s| s == "<|endoftext|>"));
        assert!(stops.iter().any(|s| s == "User:"));
    }

    #[test]
    fn test_parse_with_template() {
        let parsed = parse_ok(
            r#"
            FROM llama3
            TEMPLATE "{{ .System }}\nUser: {{ .Prompt }}\nAssistant:"
        "#,
        );
        let template = parsed.template.as_deref();

        assert!(template.is_some());
        assert!(template.unwrap().contains("System"));
    }

    #[test]
    fn test_parse_with_adapter() {
        let parsed = parse_ok(
            r#"
            FROM llama3:8b
            ADAPTER /path/to/lora.safetensors
        "#,
        );

        assert_eq!(parsed.adapter.as_deref(), Some("/path/to/lora.safetensors"));
    }

    #[test]
    fn test_parse_with_comments() {
        let parsed = parse_ok(
            r#"
            # This is a comment
            FROM llama3
            # Another comment
            SYSTEM Be helpful
        "#,
        );

        assert_eq!(parsed.base_model, "llama3");
        assert!(parsed.system_prompt.is_some());
    }

    #[test]
    fn test_parse_missing_from() {
        assert!(ModelManifest::parse("SYSTEM You are helpful.").is_err());
    }

    #[test]
    fn test_parse_empty_from() {
        assert!(ModelManifest::parse("FROM").is_err());
    }

    // Builder methods.

    #[test]
    fn test_builder() {
        let built = ModelManifest::new("llama3:8b")
            .with_system("You are a coding assistant.")
            .with_temperature(0.8)
            .with_top_p(0.95)
            .with_max_tokens(1024)
            .with_stop("<|end|>")
            .with_description("My custom model");
        let params = &built.parameters;

        assert_eq!(built.base_model, "llama3:8b");
        assert!(built.system_prompt.is_some());
        assert_eq!(params.temperature, Some(0.8));
        assert_eq!(params.top_p, Some(0.95));
        assert_eq!(params.max_tokens, Some(1024));
        assert_eq!(params.stop.len(), 1);
        assert!(built.description.is_some());
    }

    #[test]
    fn test_builder_with_metadata() {
        let built = ModelManifest::new("llama3")
            .with_metadata("author", "test")
            .with_metadata("version", "1.0");
        let lookup = |key: &str| built.metadata.get(key).map(String::as_str);

        assert_eq!(lookup("author"), Some("test"));
        assert_eq!(lookup("version"), Some("1.0"));
    }

    // Serialization.

    #[test]
    fn test_to_modelfile() {
        let rendered = ModelManifest::new("llama3:8b")
            .with_system("Be helpful")
            .with_temperature(0.7)
            .to_modelfile();

        for expected in &["FROM llama3:8b", "SYSTEM Be helpful", "PARAMETER temperature 0.7"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_roundtrip() {
        let original = ModelManifest::new("mixtral:8x7b")
            .with_system("You are an expert.")
            .with_temperature(0.9)
            .with_top_k(50)
            .with_max_tokens(2048);

        let reparsed = parse_ok(&original.to_modelfile());

        assert_eq!(reparsed.base_model, original.base_model);
        assert_eq!(reparsed.system_prompt, original.system_prompt);
        assert_eq!(reparsed.parameters.temperature, original.parameters.temperature);
        assert_eq!(reparsed.parameters.top_k, original.parameters.top_k);
        assert_eq!(reparsed.parameters.max_tokens, original.parameters.max_tokens);
    }

    #[test]
    fn test_json_roundtrip() {
        let original = ModelManifest::new("llama3").with_system("Test").with_temperature(0.5);

        let encoded = original.to_json().unwrap();
        let decoded = ModelManifest::from_json(&encoded).unwrap();

        assert_eq!(decoded.base_model, original.base_model);
        assert_eq!(decoded.system_prompt, original.system_prompt);
    }

    // Parameter aliases and less common parameters.

    #[test]
    fn test_parse_context_length_alias() {
        let parsed = parse_ok(
            r#"
            FROM llama3
            PARAMETER num_ctx 4096
        "#,
        );

        assert_eq!(parsed.parameters.context_length, Some(4096));
    }

    #[test]
    fn test_parse_max_tokens_alias() {
        let parsed = parse_ok(
            r#"
            FROM llama3
            PARAMETER num_predict 512
        "#,
        );

        assert_eq!(parsed.parameters.max_tokens, Some(512));
    }

    #[test]
    fn test_parse_repeat_penalty() {
        let parsed = parse_ok(
            r#"
            FROM llama3
            PARAMETER repeat_penalty 1.1
            PARAMETER repeat_last_n 64
        "#,
        );

        assert_eq!(parsed.parameters.repeat_penalty, Some(1.1));
        assert_eq!(parsed.parameters.repeat_last_n, Some(64));
    }

    #[test]
    fn test_parse_seed() {
        let parsed = parse_ok(
            r#"
            FROM llama3
            PARAMETER seed 42
        "#,
        );

        assert_eq!(parsed.parameters.seed, Some(42));
    }

    #[test]
    fn test_invalid_parameter_value() {
        let outcome = ModelManifest::parse(
            r#"
            FROM llama3
            PARAMETER temperature not_a_number
        "#,
        );

        assert!(outcome.is_err());
    }

    // Defaults.

    #[test]
    fn test_default_parameters() {
        let defaults = ManifestParameters::default();

        assert!(defaults.temperature.is_none());
        assert!(defaults.top_p.is_none());
        assert!(defaults.stop.is_empty());
    }

    #[test]
    fn test_default_manifest() {
        let defaults = ModelManifest::default();

        assert!(defaults.base_model.is_empty());
        assert!(defaults.system_prompt.is_none());
    }
}