// modelfile/mod.rs

1//! This module defines the [`Modelfile`] structure
2//! that represents a structured format of the [Ollama Modelfile].
3//!
4//! [Ollama Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
5
6use std::{
7    fmt::Display,
8    path::{Path, PathBuf},
9    str::FromStr,
10};
11
12use builder::ModelfileBuilder;
13use derive_more::derive::{AsRef, From};
14use error::ModelfileError;
15use instruction::{Adapter, BaseModel, License, Messages, Parameters, SystemMessage, Template};
16use parser::instructions;
17use serde::{Deserialize, Serialize};
18use strum::{AsRefStr, EnumDiscriminants, EnumIter, EnumString, IntoStaticStr, VariantArray};
19
20use crate::message::Message;
21
22pub mod builder;
23pub mod error;
24pub mod instruction;
25mod parser;
26
27#[cfg(test)]
28pub mod test_data;
29
/// Comment emitted at the top of every rendered [`Modelfile`]
/// so readers know the file was machine-generated.
const HEADER_COMMENT: &str = "# This file was generated by the Ollama-CLI client\n";
31
/// A structured representation of an [Ollama Modelfile].
///
/// Fields mirror the instructions of the Modelfile format:
/// a required base model plus optional template, system message,
/// adapter and license, and any number of parameters and messages.
///
/// [Ollama Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Modelfile {
    /// The base model to derive from (`FROM` instruction).
    pub(crate) from: BaseModel,
    /// Zero or more `PARAMETER` instructions.
    pub(crate) parameters: Parameters,
    /// Optional `TEMPLATE` instruction.
    pub(crate) template: Option<Template>,
    /// Optional `SYSTEM` instruction.
    pub(crate) system: Option<SystemMessage>,
    /// Optional `ADAPTER` instruction.
    pub(crate) adapter: Option<Adapter>,
    /// Optional `LICENSE` instruction.
    pub(crate) license: Option<License>,
    /// Zero or more `MESSAGE` instructions.
    pub(crate) messages: Messages,
}
42
43impl Modelfile {
44    pub fn render(&self) -> String {
45        let mut renderer = Renderer::default();
46        renderer.push_raw(HEADER_COMMENT);
47        renderer.push("FROM", self.from.as_str());
48        renderer.newline();
49
50        renderer.push_opt(
51            "ADAPTER",
52            self.adapter.clone().map(|file| file.to_string()).as_ref(),
53        );
54
55        renderer.push_opt("SYSTEM", self.system.as_ref());
56        renderer.push_opt("TEMPLATE", self.template.as_ref());
57        renderer.push_vec("PARAMETER", self.parameters.as_ref());
58        renderer.push_vec("MESSAGE", &self.messages);
59        renderer.push_opt("LICENSE", self.license.as_ref());
60
61        renderer.finalize()
62    }
63
64    pub fn instructions(self) -> impl Iterator<Item = Instruction> {
65        let Modelfile {
66            from,
67            parameters,
68            template,
69            system,
70            adapter,
71            license,
72            messages,
73        } = self;
74
75        std::iter::once(Instruction::from(from))
76            .chain(parameters.into_iter().map(Instruction::Parameter))
77            .chain(template.into_iter().map(Instruction::Template))
78            .chain(system.into_iter().map(Instruction::System))
79            .chain(adapter.into_iter().map(Instruction::Adapter))
80            .chain(messages.into_iter().map(Instruction::Message))
81            .chain(license.into_iter().map(Instruction::License))
82    }
83
84    pub fn build_on(self) -> ModelfileBuilder {
85        self.into()
86    }
87}
88
89impl From<BaseModel> for Modelfile {
90    fn from(from: BaseModel) -> Self {
91        Modelfile {
92            from,
93            parameters: Default::default(),
94            template: Default::default(),
95            system: Default::default(),
96            adapter: Default::default(),
97            license: Default::default(),
98            messages: Default::default(),
99        }
100    }
101}
102
103impl Display for Modelfile {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(f, "{}", self.render())
106    }
107}
108
/// A string value that may span multiple lines.
///
/// Rendered wrapped in triple quotes (`"""…"""`) — see its [`Display`] impl.
#[derive(AsRef, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[as_ref(str)]
pub struct Multiline(String);
112
113impl Multiline {
114    fn extend(&self, more: &str) -> Self {
115        let mut new = self.clone();
116        new.0.push('\n');
117        new.0.push_str(more);
118        new
119    }
120}
121
122impl From<String> for Multiline {
123    fn from(value: String) -> Self {
124        Self(value)
125    }
126}
127
128impl<'a> From<&'a str> for Multiline {
129    fn from(value: &'a str) -> Self {
130        Self(value.to_string())
131    }
132}
133
134impl std::fmt::Display for Multiline {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        write!(f, "\"\"\"{}\"\"\"", self.0)
137    }
138}
139
/// Accumulates rendered Modelfile text for [`Modelfile::render`].
#[derive(Clone, Debug, Default)]
struct Renderer {
    /// Buffer holding the rendered output so far.
    builder: String,
}
144
145impl Renderer {
146    fn push_raw(&mut self, s: &str) {
147        self.builder.push_str(s);
148    }
149
150    fn push(&mut self, name: &'static str, s: &str) {
151        self.builder.push_str(name);
152        self.builder.push(' ');
153        self.builder.push_str(s);
154        self.builder.push('\n');
155    }
156
157    fn newline(&mut self) {
158        self.builder.push('\n');
159    }
160
161    fn push_opt<T: ToString>(&mut self, name: &'static str, opt: Option<&T>) {
162        if let Some(t) = opt {
163            self.builder.push_str(name);
164            self.builder.push(' ');
165            self.builder.push_str(&t.to_string());
166            self.builder.push('\n');
167            self.builder.push('\n');
168        }
169    }
170
171    fn push_vec<T: ToString>(&mut self, name: &'static str, vec: impl AsRef<[T]>) {
172        let vec = vec.as_ref();
173        if vec.is_empty() {
174            tracing::debug!("no items passed for {}", name);
175            return;
176        }
177
178        for t in vec {
179            self.builder.push_str(name);
180            self.builder.push(' ');
181            self.builder.push_str(&t.to_string());
182            self.builder.push('\n')
183        }
184
185        self.builder.push('\n')
186    }
187
188    fn finalize(self) -> String {
189        self.builder
190    }
191}
192
193impl FromStr for Modelfile {
194    type Err = ModelfileError;
195
196    fn from_str(input: &str) -> Result<Self, Self::Err> {
197        let instructions: Vec<Instruction> = instructions(input)
198            .map_err(|error| ModelfileError::Parse(error.to_string()))
199            .and_then(|(rest, instructions)| {
200                if rest.is_empty() {
201                    Ok(instructions)
202                } else {
203                    Err(ModelfileError::Parse(
204                        "parser did not consume all input".to_string(),
205                    ))
206                }
207            })?;
208
209        instructions.try_into()
210    }
211}
212
213impl TryFrom<Vec<Instruction>> for Modelfile {
214    type Error = ModelfileError;
215
216    fn try_from(instructions: Vec<Instruction>) -> Result<Self, Self::Error> {
217        let builder = ModelfileBuilder::try_from(instructions)?;
218
219        builder.build()
220    }
221}
222
/// A single instruction to [Ollama]
/// that tells [Ollama] how to configure a language model.
/// Each instruction is defined in the [Modelfile docs]
///
/// [Ollama]: https://ollama.com/
/// [Modelfile docs]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
#[derive(
    Debug,
    Clone,
    From,
    Serialize,
    Deserialize,
    strum::Display,
    AsRefStr,
    IntoStaticStr,
    EnumDiscriminants,
)]
#[strum_discriminants(name(InstructionName))]
#[strum_discriminants(derive(IntoStaticStr, strum::Display, AsRefStr))]
#[serde(rename_all = "snake_case")]
pub enum Instruction {
    /// Some part of the file that is skipped,
    /// like an empty line or comment.
    Skip,
    /// The model to derive from,
    /// either a SHA blob,
    /// GGUF file,
    /// directory pointing to `safetensor` files,
    /// or a `<model>:<version>` identifier for an existing model.
    From(BaseModel),
    /// A runtime [`Parameter`] for the model.
    Parameter(Parameter),
    /// A [golang] [Ollama template].
    ///
    /// [golang]: https://pkg.go.dev/text/template
    /// [Ollama template]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#template
    Template(Template),
    /// The system message for the given model.
    System(SystemMessage),
    /// An adapter file to apply to the base model; see [`TensorFile`].
    Adapter(Adapter),
    /// The license text for the model.
    License(License),
    /// A conversation [`Message`] for the model.
    Message(Message),
}
265
266impl From<TensorFile> for Instruction {
267    fn from(value: TensorFile) -> Self {
268        Instruction::Adapter(value.into())
269    }
270}
271
/// A file that represents a Tensor.
/// Either a GGUF or safetensor file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TensorFile {
    /// Path to a GGUF-format file.
    Gguf(PathBuf),
    /// Path to a `safetensors`-format file.
    Safetensor(PathBuf),
}
279
280impl AsRef<Path> for TensorFile {
281    fn as_ref(&self) -> &Path {
282        match self {
283            TensorFile::Gguf(path_buf) => path_buf.as_ref(),
284            TensorFile::Safetensor(path_buf) => path_buf.as_ref(),
285        }
286    }
287}
288
289impl Display for TensorFile {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        f.write_str(&self.as_ref().display().to_string())
292    }
293}
294
/// A parameter for the model.
/// [docs]
///
/// [docs]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#parameter
#[derive(Debug, Clone, EnumDiscriminants, strum::Display, Serialize, Deserialize, PartialEq)]
#[strum_discriminants(name(ParameterName))]
#[strum_discriminants(derive(
    EnumIter,
    Hash,
    PartialOrd,
    Ord,
    IntoStaticStr,
    EnumString,
    VariantArray,
    Serialize,
    Deserialize
))]
#[strum_discriminants(strum(serialize_all = "snake_case"))]
// Each variant's `#[strum(to_string = "…")]` attribute controls how the
// parameter renders on a `PARAMETER` line (via `strum::Display`).
pub enum Parameter {
    /// Enable Mirostat sampling for controlling perplexity.
    /// (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
    #[strum(to_string = "mirostat {0}")]
    Mirostat(usize),
    /// Influences how quickly the algorithm responds
    /// to feedback from the generated text.
    /// A lower learning rate will result in slower adjustments,
    /// while a higher learning rate will make the algorithm more responsive.
    /// (Default: 0.1)
    #[strum(to_string = "mirostat_eta {0}")]
    MirostatEta(f32),
    /// Controls the balance between coherence and diversity of the output.
    /// A lower value will result in more focused and coherent text.
    /// (Default: 5.0)
    #[strum(to_string = "mirostat_tau {0}")]
    MirostatTau(f32),
    /// Sets the size of the context window
    /// used to generate the next token.
    /// (Default: 2048)
    #[strum(to_string = "num_ctx {0}")]
    NumCtx(usize),
    /// Sets how far back for the model
    /// to look back to prevent repetition.
    /// (Default: 64, 0 = disabled, -1 = num_ctx)
    // NOTE(review): the documented `-1` sentinel cannot be represented by
    // `usize` — confirm whether a signed type is intended here.
    #[strum(to_string = "repeat_last_n {0}")]
    RepeatLastN(usize),
    /// Sets how strongly to penalize repetitions.
    /// A higher value (e.g., 1.5) will penalize repetitions more strongly,
    /// while a lower value (e.g., 0.9) will be more lenient.
    /// (Default: 1.1)
    #[strum(to_string = "repeat_penalty {0}")]
    RepeatPenalty(f32),
    /// The temperature of the model.
    /// Increasing the temperature will make the model answer more creatively.
    /// (Default: 0.8)
    #[strum(to_string = "temperature {0}")]
    Temperature(f32),
    /// Sets the random number seed to use for generation.
    /// Setting this to a specific number will make the model generate the same text
    /// for the same prompt.
    /// (Default: 0)
    #[strum(to_string = "seed {0}")]
    Seed(usize),
    /// Sets the stop sequences to use.
    /// When this pattern is encountered the LLM will stop generating text and return.
    /// Multiple stop patterns may be set by specifying multiple separate stop parameters
    /// in a modelfile.
    #[strum(to_string = "stop {0}")]
    Stop(String),
    /// Tail free sampling is used to reduce the impact
    /// of less probable tokens from the output.
    /// A higher value (e.g., 2.0) will reduce the impact more,
    /// while a value of 1.0 disables this setting.
    /// (default: 1)
    #[strum(to_string = "tfs_z {0}")]
    TfsZ(f32),
    /// Maximum number of tokens to predict when generating text.
    /// (Default: 128, -1 = infinite generation, -2 = fill context)
    // NOTE(review): the documented `-1`/`-2` sentinels cannot be represented
    // by `usize` — confirm whether a signed type is intended here.
    #[strum(to_string = "num_predict {0}")]
    NumPredict(usize),
    /// Reduces the probability of generating nonsense.
    /// A higher value (e.g. 100) will give more diverse answers,
    /// while a lower value (e.g. 10) will be more conservative.
    /// (Default: 40)
    #[strum(to_string = "top_k {0}")]
    TopK(usize),
    /// Works together with top-k.
    /// A higher value (e.g., 0.95) will lead to more diverse text,
    /// while a lower value (e.g., 0.5) will generate more focused and conservative text.
    /// (Default: 0.9)
    #[strum(to_string = "top_p {0}")]
    TopP(f32),
    /// Alternative to the top_p,
    /// and aims to ensure a balance of quality and variety.
    /// The parameter p represents the minimum probability for a token to be considered,
    /// relative to the probability of the most likely token.
    /// For example, with p=0.05 and the most likely token having a probability of 0.9,
    /// logits with a value less than 0.045 are filtered out.
    /// (Default: 0.0)
    #[strum(to_string = "min_p {0}")]
    MinP(f32),
}
396
#[cfg(test)]
mod tests {
    use insta::{assert_debug_snapshot, assert_snapshot};
    use test_data::{load_modelfiles, TestData, TEST_GOOD_DATA_DIR};

    use super::{test_data::TEST_BAD_DATA_DIR, *};

    // Every Modelfile in the known-good test data directory should parse.
    #[test]
    fn modelfiles_are_parsed() {
        let modelfiles: Vec<TestData> = load_modelfiles(TEST_GOOD_DATA_DIR);

        for TestData {
            path,
            contents: case,
        } in modelfiles
        {
            // Print the path so a failing case is identifiable in output.
            dbg!(&path);
            let modelfile: Modelfile = case
                .parse::<Modelfile>()
                .expect("should be able to parse Modelfile");

            dbg!(modelfile);
        }
    }

    // Every Modelfile in the known-bad test data directory should fail,
    // all with the same builder error message.
    #[test]
    fn bad_modelfiles_are_not_parsed() {
        let modelfiles: Vec<TestData> = load_modelfiles(TEST_BAD_DATA_DIR);

        for TestData {
            path,
            contents: case,
        } in modelfiles
        {
            dbg!(&path);
            let result = case
                .parse::<Modelfile>()
                .expect_err("should not be able to parse bad Modelfiles");

            insta::assert_snapshot!(result.to_string(), @"error building Modelfile from parts");
        }
    }

    // Parsed Modelfiles should serialize cleanly through serde (TOML).
    #[test]
    fn modelfiles_can_be_rendered_as_toml() {
        let modelfiles: Vec<TestData> = load_modelfiles(TEST_GOOD_DATA_DIR);

        for TestData {
            path,
            contents: case,
        } in modelfiles
        {
            dbg!(&path);
            let modelfile: Modelfile = case
                .parse::<Modelfile>()
                .expect("should be able to parse Modelfile");

            let _rendered =
                toml::to_string(&modelfile).expect("should be able to render Modelfiles as TOML");

            dbg!(modelfile);
        }
    }

    // `render` output must itself be parseable (round-trip), and its text
    // is pinned with an insta snapshot.
    #[test]
    fn snapshot_render() {
        let modelfile: Modelfile = load_modelfiles(TEST_GOOD_DATA_DIR)
            .into_iter()
            .find(|TestData { path, contents: _ }| {
                path.file_name()
                    .expect("test data should have a valid filename")
                    .to_str()
                    .expect("should be able to convert OsStr to str")
                    == "llama3.2.latest.Modelfile"
            })
            .expect("should have at least one test case")
            .contents
            .parse()
            .expect("should be able to parse test data");

        let render = modelfile.render();

        dbg!(&render);

        // Round-trip: the rendered text must parse back into a Modelfile.
        let _modelfile: Modelfile = render
            .parse()
            .expect("should be able to parse rendered Modelfile");

        assert_snapshot!(render);
    }

    // `Parameter`'s strum `to_string` attribute drives its rendered form.
    #[test]
    fn snapshot_parameters() {
        let param = Parameter::Stop("<eos>".into());

        assert_snapshot!(param, @"stop <eos>");
    }

    // A Modelfile built from just a base model yields exactly one
    // instruction (the `FROM`).
    #[test]
    fn minimal_modelfile_produces_iterator() {
        let base_model = BaseModel::from("llama8.2");

        let modelfile = Modelfile::from(base_model);

        let instructions: Vec<Instruction> = modelfile.instructions().collect();

        assert!(!instructions.is_empty());
        assert_eq!(instructions.len(), 1);
    }

    // `instructions()` and `TryFrom<Vec<Instruction>>` are inverses:
    // decomposing and reassembling reproduces the original Modelfile.
    #[test]
    fn modelfile_instructions_snapshot() {
        let test_data: Vec<TestData> = load_modelfiles(TEST_GOOD_DATA_DIR);

        let test_data = test_data
            .into_iter()
            .find(|data| {
                data.path
                    .file_name()
                    .expect("unable to read file name")
                    .to_str()
                    .expect("unable to convert filename to string")
                    == "llama3.1.latest.Modelfile"
            })
            .expect("could not load test data");

        let modelfile: Modelfile = test_data
            .contents
            .parse()
            .expect("unable to parse test data");

        let instructions: Vec<Instruction> = modelfile.clone().instructions().collect();

        let instruction_names: Vec<InstructionName> = instructions.iter().map(Into::into).collect();

        assert_debug_snapshot!(instruction_names, @r"
        [
            From,
            Parameter,
            Parameter,
            Parameter,
            Template,
            License,
        ]
        ");

        let result: Modelfile = instructions
            .try_into()
            .expect("unable to go from instructions to Modelfile");

        assert_eq!(result, modelfile);
    }
}