Skip to main content

pacha/
manifest.rs

1//! Model Manifest (Modelfile) Support
2//!
3//! Provides a Modelfile-like configuration for custom model definitions,
4//! similar to ollama's Modelfile format.
5//!
6//! ## Modelfile Format
7//!
8//! ```text
9//! FROM llama3:8b
10//! SYSTEM You are a helpful coding assistant.
11//! PARAMETER temperature 0.7
12//! PARAMETER top_p 0.9
13//! PARAMETER stop "<|endoftext|>"
14//! TEMPLATE "{{ .System }}\nUser: {{ .Prompt }}\nAssistant:"
15//! ```
16//!
17//! ## Example
18//!
19//! ```rust,ignore
20//! use pacha::manifest::ModelManifest;
21//!
22//! let manifest = ModelManifest::parse(r#"
23//!     FROM llama3:8b
24//!     SYSTEM You are helpful.
25//!     PARAMETER temperature 0.7
26//! "#)?;
27//!
28//! println!("Base: {}", manifest.base_model);
29//! println!("System: {:?}", manifest.system_prompt);
30//! ```
31
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::path::Path;
35
36use crate::error::{PachaError, Result};
37
38// ============================================================================
39// MANIFEST-001: Model Manifest
40// ============================================================================
41
42/// Model manifest defining a custom model configuration
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelManifest {
45    /// Base model reference (FROM directive)
46    pub base_model: String,
47    /// System prompt (SYSTEM directive)
48    pub system_prompt: Option<String>,
49    /// Generation parameters (PARAMETER directives)
50    pub parameters: ManifestParameters,
51    /// Custom prompt template (TEMPLATE directive)
52    pub template: Option<String>,
53    /// Model adapter/LoRA path (ADAPTER directive)
54    pub adapter: Option<String>,
55    /// License information (LICENSE directive)
56    pub license: Option<String>,
57    /// Model description
58    pub description: Option<String>,
59    /// Custom metadata
60    pub metadata: HashMap<String, String>,
61}
62
63impl Default for ModelManifest {
64    fn default() -> Self {
65        Self {
66            base_model: String::new(),
67            system_prompt: None,
68            parameters: ManifestParameters::default(),
69            template: None,
70            adapter: None,
71            license: None,
72            description: None,
73            metadata: HashMap::new(),
74        }
75    }
76}
77
78/// Generation parameters from manifest
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ManifestParameters {
81    /// Sampling temperature
82    pub temperature: Option<f32>,
83    /// Top-p (nucleus) sampling
84    pub top_p: Option<f32>,
85    /// Top-k sampling
86    pub top_k: Option<usize>,
87    /// Maximum tokens to generate
88    pub max_tokens: Option<usize>,
89    /// Stop sequences
90    pub stop: Vec<String>,
91    /// Repetition penalty
92    pub repeat_penalty: Option<f32>,
93    /// Number of tokens to consider for repetition penalty
94    pub repeat_last_n: Option<usize>,
95    /// Context window size
96    pub context_length: Option<usize>,
97    /// Seed for reproducibility
98    pub seed: Option<u64>,
99}
100
101impl Default for ManifestParameters {
102    fn default() -> Self {
103        Self {
104            temperature: None,
105            top_p: None,
106            top_k: None,
107            max_tokens: None,
108            stop: Vec::new(),
109            repeat_penalty: None,
110            repeat_last_n: None,
111            context_length: None,
112            seed: None,
113        }
114    }
115}
116
117impl ModelManifest {
118    /// Create a new manifest with a base model
119    #[must_use]
120    pub fn new(base_model: impl Into<String>) -> Self {
121        Self { base_model: base_model.into(), ..Default::default() }
122    }
123
124    /// Set system prompt
125    #[must_use]
126    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
127        self.system_prompt = Some(prompt.into());
128        self
129    }
130
131    /// Set temperature
132    #[must_use]
133    pub fn with_temperature(mut self, temp: f32) -> Self {
134        self.parameters.temperature = Some(temp);
135        self
136    }
137
138    /// Set top_p
139    #[must_use]
140    pub fn with_top_p(mut self, top_p: f32) -> Self {
141        self.parameters.top_p = Some(top_p);
142        self
143    }
144
145    /// Set top_k
146    #[must_use]
147    pub fn with_top_k(mut self, top_k: usize) -> Self {
148        self.parameters.top_k = Some(top_k);
149        self
150    }
151
152    /// Set max tokens
153    #[must_use]
154    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
155        self.parameters.max_tokens = Some(max_tokens);
156        self
157    }
158
159    /// Add stop sequence
160    #[must_use]
161    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
162        self.parameters.stop.push(stop.into());
163        self
164    }
165
166    /// Set template
167    #[must_use]
168    pub fn with_template(mut self, template: impl Into<String>) -> Self {
169        self.template = Some(template.into());
170        self
171    }
172
173    /// Set adapter path
174    #[must_use]
175    pub fn with_adapter(mut self, adapter: impl Into<String>) -> Self {
176        self.adapter = Some(adapter.into());
177        self
178    }
179
180    /// Set description
181    #[must_use]
182    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
183        self.description = Some(desc.into());
184        self
185    }
186
187    /// Add metadata
188    #[must_use]
189    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
190        self.metadata.insert(key.into(), value.into());
191        self
192    }
193
194    /// Parse a Modelfile-format string
195    pub fn parse(content: &str) -> Result<Self> {
196        let mut manifest = Self::default();
197
198        for line in content.lines() {
199            let line = line.trim();
200
201            // Skip empty lines and comments
202            if line.is_empty() || line.starts_with('#') {
203                continue;
204            }
205
206            // Parse directive
207            let (directive, value) = if let Some(idx) = line.find(char::is_whitespace) {
208                let (d, v) = line.split_at(idx);
209                (d.to_uppercase(), v.trim())
210            } else {
211                (line.to_uppercase(), "")
212            };
213
214            match directive.as_str() {
215                "FROM" => {
216                    if value.is_empty() {
217                        return Err(PachaError::Validation(
218                            "FROM requires a model reference".to_string(),
219                        ));
220                    }
221                    manifest.base_model = value.to_string();
222                }
223                "SYSTEM" => {
224                    manifest.system_prompt = Some(value.to_string());
225                }
226                "PARAMETER" => {
227                    parse_parameter(&mut manifest.parameters, value)?;
228                }
229                "TEMPLATE" => {
230                    // Template can be multi-line with quotes
231                    let template = value.trim_matches('"').trim_matches('\'');
232                    manifest.template = Some(template.to_string());
233                }
234                "ADAPTER" => {
235                    manifest.adapter = Some(value.to_string());
236                }
237                "LICENSE" => {
238                    manifest.license = Some(value.to_string());
239                }
240                "MESSAGE" => {
241                    // MESSAGE role content - add to metadata for now
242                    manifest.metadata.insert("message".to_string(), value.to_string());
243                }
244                _ => {
245                    // Unknown directive - store as metadata
246                    manifest.metadata.insert(directive.to_lowercase(), value.to_string());
247                }
248            }
249        }
250
251        if manifest.base_model.is_empty() {
252            return Err(PachaError::Validation("Modelfile must have FROM directive".to_string()));
253        }
254
255        Ok(manifest)
256    }
257
258    /// Load manifest from file
259    pub fn load(path: &Path) -> Result<Self> {
260        let content = std::fs::read_to_string(path).map_err(|e| {
261            PachaError::Io(std::io::Error::new(
262                e.kind(),
263                format!("Failed to read {}: {}", path.display(), e),
264            ))
265        })?;
266        Self::parse(&content)
267    }
268
269    /// Save manifest to file in Modelfile format
270    pub fn save(&self, path: &Path) -> Result<()> {
271        let content = self.to_modelfile();
272        std::fs::write(path, content).map_err(|e| {
273            PachaError::Io(std::io::Error::new(
274                e.kind(),
275                format!("Failed to write {}: {}", path.display(), e),
276            ))
277        })
278    }
279
280    /// Convert to Modelfile format string
281    #[must_use]
282    pub fn to_modelfile(&self) -> String {
283        let mut lines = Vec::new();
284
285        // FROM directive (required)
286        lines.push(format!("FROM {}", self.base_model));
287
288        // SYSTEM directive
289        if let Some(ref system) = self.system_prompt {
290            lines.push(format!("SYSTEM {}", system));
291        }
292
293        // PARAMETER directives
294        if let Some(temp) = self.parameters.temperature {
295            lines.push(format!("PARAMETER temperature {}", temp));
296        }
297        if let Some(top_p) = self.parameters.top_p {
298            lines.push(format!("PARAMETER top_p {}", top_p));
299        }
300        if let Some(top_k) = self.parameters.top_k {
301            lines.push(format!("PARAMETER top_k {}", top_k));
302        }
303        if let Some(max_tokens) = self.parameters.max_tokens {
304            lines.push(format!("PARAMETER max_tokens {}", max_tokens));
305        }
306        for stop in &self.parameters.stop {
307            lines.push(format!("PARAMETER stop \"{}\"", stop));
308        }
309        if let Some(repeat_penalty) = self.parameters.repeat_penalty {
310            lines.push(format!("PARAMETER repeat_penalty {}", repeat_penalty));
311        }
312        if let Some(context_length) = self.parameters.context_length {
313            lines.push(format!("PARAMETER context_length {}", context_length));
314        }
315        if let Some(seed) = self.parameters.seed {
316            lines.push(format!("PARAMETER seed {}", seed));
317        }
318
319        // TEMPLATE directive
320        if let Some(ref template) = self.template {
321            lines.push(format!("TEMPLATE \"{}\"", template));
322        }
323
324        // ADAPTER directive
325        if let Some(ref adapter) = self.adapter {
326            lines.push(format!("ADAPTER {}", adapter));
327        }
328
329        // LICENSE directive
330        if let Some(ref license) = self.license {
331            lines.push(format!("LICENSE {}", license));
332        }
333
334        lines.join("\n")
335    }
336
337    /// Convert to JSON
338    pub fn to_json(&self) -> Result<String> {
339        serde_json::to_string_pretty(self)
340            .map_err(|e| PachaError::Validation(format!("Failed to serialize manifest: {}", e)))
341    }
342
343    /// Parse from JSON
344    pub fn from_json(json: &str) -> Result<Self> {
345        serde_json::from_str(json)
346            .map_err(|e| PachaError::Validation(format!("Failed to parse manifest JSON: {}", e)))
347    }
348}
349
350/// Parse a PARAMETER directive value
351fn parse_parameter(params: &mut ManifestParameters, value: &str) -> Result<()> {
352    let parts: Vec<&str> = value.splitn(2, char::is_whitespace).collect();
353    if parts.len() != 2 {
354        return Err(PachaError::Validation(format!("Invalid PARAMETER syntax: {}", value)));
355    }
356    let (name, val) = (parts[0].to_lowercase(), parts[1].trim());
357    apply_parameter(params, &name, val)
358}
359
360fn apply_parameter(params: &mut ManifestParameters, name: &str, val: &str) -> Result<()> {
361    match name {
362        "temperature" => params.temperature = Some(parse_named(val, "temperature")?),
363        "top_p" => params.top_p = Some(parse_named(val, "top_p")?),
364        "top_k" => params.top_k = Some(parse_named(val, "top_k")?),
365        "max_tokens" | "num_predict" => params.max_tokens = Some(parse_named(val, "max_tokens")?),
366        "stop" => {
367            let stop = val.trim_matches('"').trim_matches('\'');
368            params.stop.push(stop.to_string());
369        }
370        "repeat_penalty" => params.repeat_penalty = Some(parse_named(val, "repeat_penalty")?),
371        "repeat_last_n" => params.repeat_last_n = Some(parse_named(val, "repeat_last_n")?),
372        "context_length" | "num_ctx" => {
373            params.context_length = Some(parse_named(val, "context_length")?);
374        }
375        "seed" => params.seed = Some(parse_named(val, "seed")?),
376        _ => {
377            // Ignore unknown parameters
378        }
379    }
380    Ok(())
381}
382
383fn parse_named<T: std::str::FromStr>(val: &str, field: &str) -> Result<T> {
384    val.parse().map_err(|_| PachaError::Validation(format!("Invalid {field}: {val}")))
385}
386
387// ============================================================================
388// Tests
389// ============================================================================
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src` and panic if the parser rejects it.
    fn parsed(src: &str) -> ModelManifest {
        ModelManifest::parse(src).unwrap()
    }

    // ========================================================================
    // Parse Tests
    // ========================================================================

    /// A lone FROM line is a complete manifest.
    #[test]
    fn test_parse_minimal() {
        assert_eq!(parsed("FROM llama3").base_model, "llama3");
    }

    /// SYSTEM text is captured verbatim after the keyword.
    #[test]
    fn test_parse_with_system() {
        let manifest = parsed(
            r#"
            FROM llama3:8b
            SYSTEM You are a helpful assistant.
            "#,
        );

        assert_eq!(manifest.base_model, "llama3:8b");
        assert_eq!(manifest.system_prompt.as_deref(), Some("You are a helpful assistant."));
    }

    /// Numeric PARAMETER lines land in the matching fields.
    #[test]
    fn test_parse_with_parameters() {
        let params = parsed(
            r#"
            FROM mistral
            PARAMETER temperature 0.7
            PARAMETER top_p 0.9
            PARAMETER top_k 40
            PARAMETER max_tokens 256
            "#,
        )
        .parameters;

        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.top_k, Some(40));
        assert_eq!(params.max_tokens, Some(256));
    }

    /// Repeated `stop` parameters accumulate, with their quotes removed.
    #[test]
    fn test_parse_with_stop_sequences() {
        let stops = parsed(
            r#"
            FROM llama3
            PARAMETER stop "<|endoftext|>"
            PARAMETER stop "User:"
            "#,
        )
        .parameters
        .stop;

        assert_eq!(stops.len(), 2);
        assert!(stops.iter().any(|s| s == "<|endoftext|>"));
        assert!(stops.iter().any(|s| s == "User:"));
    }

    /// A quoted TEMPLATE value is stored.
    #[test]
    fn test_parse_with_template() {
        let manifest = parsed(
            r#"
            FROM llama3
            TEMPLATE "{{ .System }}\nUser: {{ .Prompt }}\nAssistant:"
            "#,
        );

        assert!(manifest.template.is_some());
        assert!(manifest.template.as_ref().unwrap().contains("System"));
    }

    /// ADAPTER keeps the path as written.
    #[test]
    fn test_parse_with_adapter() {
        let manifest = parsed(
            r#"
            FROM llama3:8b
            ADAPTER /path/to/lora.safetensors
            "#,
        );

        assert_eq!(manifest.adapter.as_deref(), Some("/path/to/lora.safetensors"));
    }

    /// `#` lines are skipped wherever they appear.
    #[test]
    fn test_parse_with_comments() {
        let manifest = parsed(
            r#"
            # This is a comment
            FROM llama3
            # Another comment
            SYSTEM Be helpful
            "#,
        );

        assert_eq!(manifest.base_model, "llama3");
        assert!(manifest.system_prompt.is_some());
    }

    /// A manifest without FROM is rejected.
    #[test]
    fn test_parse_missing_from() {
        assert!(ModelManifest::parse("SYSTEM You are helpful.").is_err());
    }

    /// FROM without a model reference is rejected.
    #[test]
    fn test_parse_empty_from() {
        assert!(ModelManifest::parse("FROM").is_err());
    }

    // ========================================================================
    // Builder Tests
    // ========================================================================

    /// Chained builder calls set the corresponding fields.
    #[test]
    fn test_builder() {
        let manifest = ModelManifest::new("llama3:8b")
            .with_system("You are a coding assistant.")
            .with_temperature(0.8)
            .with_top_p(0.95)
            .with_max_tokens(1024)
            .with_stop("<|end|>")
            .with_description("My custom model");
        let params = &manifest.parameters;

        assert_eq!(manifest.base_model, "llama3:8b");
        assert!(manifest.system_prompt.is_some());
        assert_eq!(params.temperature, Some(0.8));
        assert_eq!(params.top_p, Some(0.95));
        assert_eq!(params.max_tokens, Some(1024));
        assert_eq!(params.stop.len(), 1);
        assert!(manifest.description.is_some());
    }

    /// Metadata entries are retrievable under their keys.
    #[test]
    fn test_builder_with_metadata() {
        let manifest = ModelManifest::new("llama3")
            .with_metadata("author", "test")
            .with_metadata("version", "1.0");
        let lookup = |key: &str| manifest.metadata.get(key).map(String::as_str);

        assert_eq!(lookup("author"), Some("test"));
        assert_eq!(lookup("version"), Some("1.0"));
    }

    // ========================================================================
    // Serialization Tests
    // ========================================================================

    /// The rendered Modelfile contains one line per populated directive.
    #[test]
    fn test_to_modelfile() {
        let modelfile = ModelManifest::new("llama3:8b")
            .with_system("Be helpful")
            .with_temperature(0.7)
            .to_modelfile();

        for expected in ["FROM llama3:8b", "SYSTEM Be helpful", "PARAMETER temperature 0.7"] {
            assert!(modelfile.contains(expected));
        }
    }

    /// Rendering and re-parsing preserves the fields that were set.
    #[test]
    fn test_roundtrip() {
        let original = ModelManifest::new("mixtral:8x7b")
            .with_system("You are an expert.")
            .with_temperature(0.9)
            .with_top_k(50)
            .with_max_tokens(2048);

        let parsed = parsed(&original.to_modelfile());

        assert_eq!(parsed.base_model, original.base_model);
        assert_eq!(parsed.system_prompt, original.system_prompt);
        assert_eq!(parsed.parameters.temperature, original.parameters.temperature);
        assert_eq!(parsed.parameters.top_k, original.parameters.top_k);
        assert_eq!(parsed.parameters.max_tokens, original.parameters.max_tokens);
    }

    /// JSON serialization followed by deserialization keeps the core fields.
    #[test]
    fn test_json_roundtrip() {
        let original = ModelManifest::new("llama3").with_system("Test").with_temperature(0.5);

        let json = original.to_json().unwrap();
        let restored = ModelManifest::from_json(&json).unwrap();

        assert_eq!(restored.base_model, original.base_model);
        assert_eq!(restored.system_prompt, original.system_prompt);
    }

    // ========================================================================
    // Parameter Parsing Tests
    // ========================================================================

    /// `num_ctx` is an alias for `context_length`.
    #[test]
    fn test_parse_context_length_alias() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER num_ctx 4096
            "#,
        );

        assert_eq!(manifest.parameters.context_length, Some(4096));
    }

    /// `num_predict` is an alias for `max_tokens`.
    #[test]
    fn test_parse_max_tokens_alias() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER num_predict 512
            "#,
        );

        assert_eq!(manifest.parameters.max_tokens, Some(512));
    }

    /// Both repetition parameters are parsed.
    #[test]
    fn test_parse_repeat_penalty() {
        let params = parsed(
            r#"
            FROM llama3
            PARAMETER repeat_penalty 1.1
            PARAMETER repeat_last_n 64
            "#,
        )
        .parameters;

        assert_eq!(params.repeat_penalty, Some(1.1));
        assert_eq!(params.repeat_last_n, Some(64));
    }

    /// The seed is parsed as an integer.
    #[test]
    fn test_parse_seed() {
        let manifest = parsed(
            r#"
            FROM llama3
            PARAMETER seed 42
            "#,
        );

        assert_eq!(manifest.parameters.seed, Some(42));
    }

    /// A non-numeric value for a numeric parameter fails the whole parse.
    #[test]
    fn test_invalid_parameter_value() {
        let outcome = ModelManifest::parse(
            r#"
            FROM llama3
            PARAMETER temperature not_a_number
            "#,
        );
        assert!(outcome.is_err());
    }

    // ========================================================================
    // Default Tests
    // ========================================================================

    /// Default parameters have nothing set.
    #[test]
    fn test_default_parameters() {
        let ManifestParameters { temperature, top_p, stop, .. } = ManifestParameters::default();
        assert!(temperature.is_none());
        assert!(top_p.is_none());
        assert!(stop.is_empty());
    }

    /// A default manifest has no base model and no system prompt.
    #[test]
    fn test_default_manifest() {
        let manifest = ModelManifest::default();
        assert!(manifest.base_model.is_empty());
        assert!(manifest.system_prompt.is_none());
    }
}