Skip to main content

pacha/
manifest.rs

1//! Model Manifest (Modelfile) Support
2//!
3//! Provides a Modelfile-like configuration for custom model definitions,
4//! similar to ollama's Modelfile format.
5//!
6//! ## Modelfile Format
7//!
8//! ```text
9//! FROM llama3:8b
10//! SYSTEM You are a helpful coding assistant.
11//! PARAMETER temperature 0.7
12//! PARAMETER top_p 0.9
13//! PARAMETER stop "<|endoftext|>"
14//! TEMPLATE "{{ .System }}\nUser: {{ .Prompt }}\nAssistant:"
15//! ```
16//!
17//! ## Example
18//!
19//! ```rust,ignore
20//! use pacha::manifest::ModelManifest;
21//!
22//! let manifest = ModelManifest::parse(r#"
23//!     FROM llama3:8b
24//!     SYSTEM You are helpful.
25//!     PARAMETER temperature 0.7
26//! "#)?;
27//!
28//! println!("Base: {}", manifest.base_model);
29//! println!("System: {:?}", manifest.system_prompt);
30//! ```
31
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::path::Path;
35
36use crate::error::{PachaError, Result};
37
38// ============================================================================
39// MANIFEST-001: Model Manifest
40// ============================================================================
41
42/// Model manifest defining a custom model configuration
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelManifest {
45    /// Base model reference (FROM directive)
46    pub base_model: String,
47    /// System prompt (SYSTEM directive)
48    pub system_prompt: Option<String>,
49    /// Generation parameters (PARAMETER directives)
50    pub parameters: ManifestParameters,
51    /// Custom prompt template (TEMPLATE directive)
52    pub template: Option<String>,
53    /// Model adapter/LoRA path (ADAPTER directive)
54    pub adapter: Option<String>,
55    /// License information (LICENSE directive)
56    pub license: Option<String>,
57    /// Model description
58    pub description: Option<String>,
59    /// Custom metadata
60    pub metadata: HashMap<String, String>,
61}
62
63impl Default for ModelManifest {
64    fn default() -> Self {
65        Self {
66            base_model: String::new(),
67            system_prompt: None,
68            parameters: ManifestParameters::default(),
69            template: None,
70            adapter: None,
71            license: None,
72            description: None,
73            metadata: HashMap::new(),
74        }
75    }
76}
77
78/// Generation parameters from manifest
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ManifestParameters {
81    /// Sampling temperature
82    pub temperature: Option<f32>,
83    /// Top-p (nucleus) sampling
84    pub top_p: Option<f32>,
85    /// Top-k sampling
86    pub top_k: Option<usize>,
87    /// Maximum tokens to generate
88    pub max_tokens: Option<usize>,
89    /// Stop sequences
90    pub stop: Vec<String>,
91    /// Repetition penalty
92    pub repeat_penalty: Option<f32>,
93    /// Number of tokens to consider for repetition penalty
94    pub repeat_last_n: Option<usize>,
95    /// Context window size
96    pub context_length: Option<usize>,
97    /// Seed for reproducibility
98    pub seed: Option<u64>,
99}
100
101impl Default for ManifestParameters {
102    fn default() -> Self {
103        Self {
104            temperature: None,
105            top_p: None,
106            top_k: None,
107            max_tokens: None,
108            stop: Vec::new(),
109            repeat_penalty: None,
110            repeat_last_n: None,
111            context_length: None,
112            seed: None,
113        }
114    }
115}
116
117impl ModelManifest {
118    /// Create a new manifest with a base model
119    #[must_use]
120    pub fn new(base_model: impl Into<String>) -> Self {
121        Self { base_model: base_model.into(), ..Default::default() }
122    }
123
124    /// Set system prompt
125    #[must_use]
126    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
127        self.system_prompt = Some(prompt.into());
128        self
129    }
130
131    /// Set temperature
132    #[must_use]
133    pub fn with_temperature(mut self, temp: f32) -> Self {
134        self.parameters.temperature = Some(temp);
135        self
136    }
137
138    /// Set top_p
139    #[must_use]
140    pub fn with_top_p(mut self, top_p: f32) -> Self {
141        self.parameters.top_p = Some(top_p);
142        self
143    }
144
145    /// Set top_k
146    #[must_use]
147    pub fn with_top_k(mut self, top_k: usize) -> Self {
148        self.parameters.top_k = Some(top_k);
149        self
150    }
151
152    /// Set max tokens
153    #[must_use]
154    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
155        self.parameters.max_tokens = Some(max_tokens);
156        self
157    }
158
159    /// Add stop sequence
160    #[must_use]
161    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
162        self.parameters.stop.push(stop.into());
163        self
164    }
165
166    /// Set template
167    #[must_use]
168    pub fn with_template(mut self, template: impl Into<String>) -> Self {
169        self.template = Some(template.into());
170        self
171    }
172
173    /// Set adapter path
174    #[must_use]
175    pub fn with_adapter(mut self, adapter: impl Into<String>) -> Self {
176        self.adapter = Some(adapter.into());
177        self
178    }
179
180    /// Set description
181    #[must_use]
182    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
183        self.description = Some(desc.into());
184        self
185    }
186
187    /// Add metadata
188    #[must_use]
189    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
190        self.metadata.insert(key.into(), value.into());
191        self
192    }
193
194    /// Parse a Modelfile-format string
195    pub fn parse(content: &str) -> Result<Self> {
196        let mut manifest = Self::default();
197
198        for line in content.lines() {
199            let line = line.trim();
200
201            // Skip empty lines and comments
202            if line.is_empty() || line.starts_with('#') {
203                continue;
204            }
205
206            // Parse directive
207            let (directive, value) = if let Some(idx) = line.find(char::is_whitespace) {
208                let (d, v) = line.split_at(idx);
209                (d.to_uppercase(), v.trim())
210            } else {
211                (line.to_uppercase(), "")
212            };
213
214            match directive.as_str() {
215                "FROM" => {
216                    if value.is_empty() {
217                        return Err(PachaError::Validation(
218                            "FROM requires a model reference".to_string(),
219                        ));
220                    }
221                    manifest.base_model = value.to_string();
222                }
223                "SYSTEM" => {
224                    manifest.system_prompt = Some(value.to_string());
225                }
226                "PARAMETER" => {
227                    parse_parameter(&mut manifest.parameters, value)?;
228                }
229                "TEMPLATE" => {
230                    // Template can be multi-line with quotes
231                    let template = value.trim_matches('"').trim_matches('\'');
232                    manifest.template = Some(template.to_string());
233                }
234                "ADAPTER" => {
235                    manifest.adapter = Some(value.to_string());
236                }
237                "LICENSE" => {
238                    manifest.license = Some(value.to_string());
239                }
240                "MESSAGE" => {
241                    // MESSAGE role content - add to metadata for now
242                    manifest.metadata.insert("message".to_string(), value.to_string());
243                }
244                _ => {
245                    // Unknown directive - store as metadata
246                    manifest.metadata.insert(directive.to_lowercase(), value.to_string());
247                }
248            }
249        }
250
251        if manifest.base_model.is_empty() {
252            return Err(PachaError::Validation("Modelfile must have FROM directive".to_string()));
253        }
254
255        Ok(manifest)
256    }
257
258    /// Load manifest from file
259    pub fn load(path: &Path) -> Result<Self> {
260        let content = std::fs::read_to_string(path).map_err(|e| {
261            PachaError::Io(std::io::Error::new(
262                e.kind(),
263                format!("Failed to read {}: {}", path.display(), e),
264            ))
265        })?;
266        Self::parse(&content)
267    }
268
269    /// Save manifest to file in Modelfile format
270    pub fn save(&self, path: &Path) -> Result<()> {
271        let content = self.to_modelfile();
272        std::fs::write(path, content).map_err(|e| {
273            PachaError::Io(std::io::Error::new(
274                e.kind(),
275                format!("Failed to write {}: {}", path.display(), e),
276            ))
277        })
278    }
279
280    /// Convert to Modelfile format string
281    #[must_use]
282    pub fn to_modelfile(&self) -> String {
283        let mut lines = Vec::new();
284
285        // FROM directive (required)
286        lines.push(format!("FROM {}", self.base_model));
287
288        // SYSTEM directive
289        if let Some(ref system) = self.system_prompt {
290            lines.push(format!("SYSTEM {}", system));
291        }
292
293        // PARAMETER directives
294        if let Some(temp) = self.parameters.temperature {
295            lines.push(format!("PARAMETER temperature {}", temp));
296        }
297        if let Some(top_p) = self.parameters.top_p {
298            lines.push(format!("PARAMETER top_p {}", top_p));
299        }
300        if let Some(top_k) = self.parameters.top_k {
301            lines.push(format!("PARAMETER top_k {}", top_k));
302        }
303        if let Some(max_tokens) = self.parameters.max_tokens {
304            lines.push(format!("PARAMETER max_tokens {}", max_tokens));
305        }
306        for stop in &self.parameters.stop {
307            lines.push(format!("PARAMETER stop \"{}\"", stop));
308        }
309        if let Some(repeat_penalty) = self.parameters.repeat_penalty {
310            lines.push(format!("PARAMETER repeat_penalty {}", repeat_penalty));
311        }
312        if let Some(context_length) = self.parameters.context_length {
313            lines.push(format!("PARAMETER context_length {}", context_length));
314        }
315        if let Some(seed) = self.parameters.seed {
316            lines.push(format!("PARAMETER seed {}", seed));
317        }
318
319        // TEMPLATE directive
320        if let Some(ref template) = self.template {
321            lines.push(format!("TEMPLATE \"{}\"", template));
322        }
323
324        // ADAPTER directive
325        if let Some(ref adapter) = self.adapter {
326            lines.push(format!("ADAPTER {}", adapter));
327        }
328
329        // LICENSE directive
330        if let Some(ref license) = self.license {
331            lines.push(format!("LICENSE {}", license));
332        }
333
334        lines.join("\n")
335    }
336
337    /// Convert to JSON
338    pub fn to_json(&self) -> Result<String> {
339        serde_json::to_string_pretty(self)
340            .map_err(|e| PachaError::Validation(format!("Failed to serialize manifest: {}", e)))
341    }
342
343    /// Parse from JSON
344    pub fn from_json(json: &str) -> Result<Self> {
345        serde_json::from_str(json)
346            .map_err(|e| PachaError::Validation(format!("Failed to parse manifest JSON: {}", e)))
347    }
348}
349
350/// Parse a PARAMETER directive value
351fn parse_parameter(params: &mut ManifestParameters, value: &str) -> Result<()> {
352    let parts: Vec<&str> = value.splitn(2, char::is_whitespace).collect();
353    if parts.len() != 2 {
354        return Err(PachaError::Validation(format!("Invalid PARAMETER syntax: {}", value)));
355    }
356
357    let (name, val) = (parts[0].to_lowercase(), parts[1].trim());
358
359    match name.as_str() {
360        "temperature" => {
361            params.temperature =
362                Some(val.parse().map_err(|_| {
363                    PachaError::Validation(format!("Invalid temperature: {}", val))
364                })?);
365        }
366        "top_p" => {
367            params.top_p = Some(
368                val.parse()
369                    .map_err(|_| PachaError::Validation(format!("Invalid top_p: {}", val)))?,
370            );
371        }
372        "top_k" => {
373            params.top_k = Some(
374                val.parse()
375                    .map_err(|_| PachaError::Validation(format!("Invalid top_k: {}", val)))?,
376            );
377        }
378        "max_tokens" | "num_predict" => {
379            params.max_tokens = Some(
380                val.parse()
381                    .map_err(|_| PachaError::Validation(format!("Invalid max_tokens: {}", val)))?,
382            );
383        }
384        "stop" => {
385            let stop = val.trim_matches('"').trim_matches('\'');
386            params.stop.push(stop.to_string());
387        }
388        "repeat_penalty" => {
389            params.repeat_penalty =
390                Some(val.parse().map_err(|_| {
391                    PachaError::Validation(format!("Invalid repeat_penalty: {}", val))
392                })?);
393        }
394        "repeat_last_n" => {
395            params.repeat_last_n =
396                Some(val.parse().map_err(|_| {
397                    PachaError::Validation(format!("Invalid repeat_last_n: {}", val))
398                })?);
399        }
400        "context_length" | "num_ctx" => {
401            params.context_length =
402                Some(val.parse().map_err(|_| {
403                    PachaError::Validation(format!("Invalid context_length: {}", val))
404                })?);
405        }
406        "seed" => {
407            params.seed = Some(
408                val.parse()
409                    .map_err(|_| PachaError::Validation(format!("Invalid seed: {}", val)))?,
410            );
411        }
412        _ => {
413            // Ignore unknown parameters
414        }
415    }
416
417    Ok(())
418}
419
420// ============================================================================
421// Tests
422// ============================================================================
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, panicking when the parser rejects it.
    fn parsed(src: &str) -> ModelManifest {
        ModelManifest::parse(src).unwrap()
    }

    // ========================================================================
    // Parse Tests
    // ========================================================================

    /// A lone FROM line is a complete manifest.
    #[test]
    fn test_parse_minimal() {
        assert_eq!(parsed("FROM llama3").base_model, "llama3");
    }

    /// SYSTEM captures the rest of its line as the system prompt.
    #[test]
    fn test_parse_with_system() {
        let manifest = parsed(
            r#"
            FROM llama3:8b
            SYSTEM You are a helpful assistant.
            "#,
        );

        assert_eq!(manifest.base_model, "llama3:8b");
        assert_eq!(manifest.system_prompt.as_deref(), Some("You are a helpful assistant."));
    }

    /// Numeric PARAMETER lines land in their typed fields.
    #[test]
    fn test_parse_with_parameters() {
        let manifest = parsed(
            r#"
            FROM mistral
            PARAMETER temperature 0.7
            PARAMETER top_p 0.9
            PARAMETER top_k 40
            PARAMETER max_tokens 256
            "#,
        );

        let params = &manifest.parameters;
        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.top_k, Some(40));
        assert_eq!(params.max_tokens, Some(256));
    }

    /// Repeated `stop` parameters accumulate and lose their quotes.
    #[test]
    fn test_parse_with_stop_sequences() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER stop "<|endoftext|>"
            PARAMETER stop "User:"
            "#,
        );

        let stops = &manifest.parameters.stop;
        assert_eq!(stops.len(), 2);
        assert!(stops.iter().any(|s| s == "<|endoftext|>"));
        assert!(stops.iter().any(|s| s == "User:"));
    }

    /// TEMPLATE stores the quoted template text.
    #[test]
    fn test_parse_with_template() {
        let manifest = parsed(
            r#"
            FROM llama3
            TEMPLATE "{{ .System }}\nUser: {{ .Prompt }}\nAssistant:"
            "#,
        );

        let template = manifest.template.expect("TEMPLATE directive should set the template");
        assert!(template.contains("System"));
    }

    /// ADAPTER stores the adapter path verbatim.
    #[test]
    fn test_parse_with_adapter() {
        let manifest = parsed(
            r#"
            FROM llama3:8b
            ADAPTER /path/to/lora.safetensors
            "#,
        );

        assert_eq!(manifest.adapter.as_deref(), Some("/path/to/lora.safetensors"));
    }

    /// Lines starting with `#` are skipped.
    #[test]
    fn test_parse_with_comments() {
        let manifest = parsed(
            r#"
            # This is a comment
            FROM llama3
            # Another comment
            SYSTEM Be helpful
            "#,
        );

        assert_eq!(manifest.base_model, "llama3");
        assert_ne!(manifest.system_prompt, None);
    }

    /// A manifest without FROM is rejected.
    #[test]
    fn test_parse_missing_from() {
        assert!(ModelManifest::parse("SYSTEM You are helpful.").is_err());
    }

    /// FROM without a model reference is rejected.
    #[test]
    fn test_parse_empty_from() {
        assert!(ModelManifest::parse("FROM").is_err());
    }

    // ========================================================================
    // Builder Tests
    // ========================================================================

    /// Chained builder calls populate the matching fields.
    #[test]
    fn test_builder() {
        let manifest = ModelManifest::new("llama3:8b")
            .with_system("You are a coding assistant.")
            .with_temperature(0.8)
            .with_top_p(0.95)
            .with_max_tokens(1024)
            .with_stop("<|end|>")
            .with_description("My custom model");

        assert_eq!(manifest.base_model, "llama3:8b");
        assert_ne!(manifest.system_prompt, None);
        assert_ne!(manifest.description, None);

        let params = &manifest.parameters;
        assert_eq!(params.temperature, Some(0.8));
        assert_eq!(params.top_p, Some(0.95));
        assert_eq!(params.max_tokens, Some(1024));
        assert_eq!(params.stop.len(), 1);
    }

    /// Metadata entries are retrievable by key.
    #[test]
    fn test_builder_with_metadata() {
        let manifest = ModelManifest::new("llama3")
            .with_metadata("author", "test")
            .with_metadata("version", "1.0");

        let lookup = |key: &str| manifest.metadata.get(key).map(String::as_str);
        assert_eq!(lookup("author"), Some("test"));
        assert_eq!(lookup("version"), Some("1.0"));
    }

    // ========================================================================
    // Serialization Tests
    // ========================================================================

    /// The rendered Modelfile contains one line per populated directive.
    #[test]
    fn test_to_modelfile() {
        let modelfile = ModelManifest::new("llama3:8b")
            .with_system("Be helpful")
            .with_temperature(0.7)
            .to_modelfile();

        for expected in ["FROM llama3:8b", "SYSTEM Be helpful", "PARAMETER temperature 0.7"] {
            assert!(modelfile.contains(expected));
        }
    }

    /// Rendering and re-parsing keeps the compared fields intact.
    #[test]
    fn test_roundtrip() {
        let original = ModelManifest::new("mixtral:8x7b")
            .with_system("You are an expert.")
            .with_temperature(0.9)
            .with_top_k(50)
            .with_max_tokens(2048);

        let reparsed = ModelManifest::parse(&original.to_modelfile()).unwrap();

        assert_eq!(reparsed.base_model, original.base_model);
        assert_eq!(reparsed.system_prompt, original.system_prompt);
        assert_eq!(reparsed.parameters.temperature, original.parameters.temperature);
        assert_eq!(reparsed.parameters.top_k, original.parameters.top_k);
        assert_eq!(reparsed.parameters.max_tokens, original.parameters.max_tokens);
    }

    /// JSON serialization followed by deserialization keeps the compared fields.
    #[test]
    fn test_json_roundtrip() {
        let original = ModelManifest::new("llama3").with_system("Test").with_temperature(0.5);

        let encoded = original.to_json().unwrap();
        let decoded = ModelManifest::from_json(&encoded).unwrap();

        assert_eq!(decoded.base_model, original.base_model);
        assert_eq!(decoded.system_prompt, original.system_prompt);
    }

    // ========================================================================
    // Parameter Parsing Tests
    // ========================================================================

    /// `num_ctx` is an alias for `context_length`.
    #[test]
    fn test_parse_context_length_alias() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER num_ctx 4096
            "#,
        );

        assert_eq!(manifest.parameters.context_length, Some(4096));
    }

    /// `num_predict` is an alias for `max_tokens`.
    #[test]
    fn test_parse_max_tokens_alias() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER num_predict 512
            "#,
        );

        assert_eq!(manifest.parameters.max_tokens, Some(512));
    }

    /// Both repetition-penalty parameters are parsed.
    #[test]
    fn test_parse_repeat_penalty() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER repeat_penalty 1.1
            PARAMETER repeat_last_n 64
            "#,
        );

        let params = &manifest.parameters;
        assert_eq!(params.repeat_penalty, Some(1.1));
        assert_eq!(params.repeat_last_n, Some(64));
    }

    /// `seed` is parsed as an integer.
    #[test]
    fn test_parse_seed() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER seed 42
            "#,
        );

        assert_eq!(manifest.parameters.seed, Some(42));
    }

    /// A non-numeric temperature fails the whole parse.
    #[test]
    fn test_invalid_parameter_value() {
        let outcome = ModelManifest::parse(
            r#"
            FROM llama3
            PARAMETER temperature not_a_number
            "#,
        );
        assert!(outcome.is_err());
    }

    // ========================================================================
    // Default Tests
    // ========================================================================

    /// Default parameters carry no values.
    #[test]
    fn test_default_parameters() {
        let params = ManifestParameters::default();
        assert_eq!(params.temperature, None);
        assert_eq!(params.top_p, None);
        assert_eq!(params.stop.len(), 0);
    }

    /// A default manifest has no base model and no system prompt.
    #[test]
    fn test_default_manifest() {
        let manifest = ModelManifest::default();
        assert_eq!(manifest.base_model, "");
        assert_eq!(manifest.system_prompt, None);
    }
}