1use std::{
7 fmt::Display,
8 path::{Path, PathBuf},
9 str::FromStr,
10};
11
12use builder::ModelfileBuilder;
13use derive_more::derive::{AsRef, From};
14use error::ModelfileError;
15use instruction::{Adapter, BaseModel, License, Messages, Parameters, SystemMessage, Template};
16use parser::instructions;
17use serde::{Deserialize, Serialize};
18use strum::{AsRefStr, EnumDiscriminants, EnumIter, EnumString, IntoStaticStr, VariantArray};
19
20use crate::message::Message;
21
22pub mod builder;
23pub mod error;
24pub mod instruction;
25mod parser;
26
27#[cfg(test)]
28pub mod test_data;
29
/// Comment line written as the very first line of every Modelfile produced by
/// `Modelfile::render`, marking the file as machine-generated.
const HEADER_COMMENT: &str = "# This file was generated by the Ollama-CLI client\n";
31
/// In-memory representation of an Ollama Modelfile.
///
/// Parse one from text with [`FromStr`] (`"...".parse::<Modelfile>()`), turn it back into
/// text with [`Modelfile::render`] (or `Display`), or decompose it into [`Instruction`]s
/// with [`Modelfile::instructions`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Modelfile {
    /// Base model named by the `FROM` instruction (the only mandatory part).
    pub(crate) from: BaseModel,
    /// `PARAMETER` instructions, rendered one line each.
    pub(crate) parameters: Parameters,
    /// Optional `TEMPLATE` instruction.
    pub(crate) template: Option<Template>,
    /// Optional `SYSTEM` instruction.
    pub(crate) system: Option<SystemMessage>,
    /// Optional `ADAPTER` instruction.
    pub(crate) adapter: Option<Adapter>,
    /// Optional `LICENSE` instruction.
    pub(crate) license: Option<License>,
    /// `MESSAGE` instructions, rendered one line each.
    pub(crate) messages: Messages,
}
42
43impl Modelfile {
44 pub fn render(&self) -> String {
45 let mut renderer = Renderer::default();
46 renderer.push_raw(HEADER_COMMENT);
47 renderer.push("FROM", self.from.as_str());
48 renderer.newline();
49
50 renderer.push_opt(
51 "ADAPTER",
52 self.adapter.clone().map(|file| file.to_string()).as_ref(),
53 );
54
55 renderer.push_opt("SYSTEM", self.system.as_ref());
56 renderer.push_opt("TEMPLATE", self.template.as_ref());
57 renderer.push_vec("PARAMETER", self.parameters.as_ref());
58 renderer.push_vec("MESSAGE", &self.messages);
59 renderer.push_opt("LICENSE", self.license.as_ref());
60
61 renderer.finalize()
62 }
63
64 pub fn instructions(self) -> impl Iterator<Item = Instruction> {
65 let Modelfile {
66 from,
67 parameters,
68 template,
69 system,
70 adapter,
71 license,
72 messages,
73 } = self;
74
75 std::iter::once(Instruction::from(from))
76 .chain(parameters.into_iter().map(Instruction::Parameter))
77 .chain(template.into_iter().map(Instruction::Template))
78 .chain(system.into_iter().map(Instruction::System))
79 .chain(adapter.into_iter().map(Instruction::Adapter))
80 .chain(messages.into_iter().map(Instruction::Message))
81 .chain(license.into_iter().map(Instruction::License))
82 }
83
84 pub fn build_on(self) -> ModelfileBuilder {
85 self.into()
86 }
87}
88
89impl From<BaseModel> for Modelfile {
90 fn from(from: BaseModel) -> Self {
91 Modelfile {
92 from,
93 parameters: Default::default(),
94 template: Default::default(),
95 system: Default::default(),
96 adapter: Default::default(),
97 license: Default::default(),
98 messages: Default::default(),
99 }
100 }
101}
102
103impl Display for Modelfile {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(f, "{}", self.render())
106 }
107}
108
/// A text value that may span several lines.
///
/// Its `Display` impl wraps the text in triple double-quotes (`"""…"""`). `AsRef<str>`
/// exposes the raw, unquoted text.
#[derive(AsRef, Debug, Clone, Serialize, Deserialize, PartialEq)]
#[as_ref(str)]
pub struct Multiline(String);
112
113impl Multiline {
114 fn extend(&self, more: &str) -> Self {
115 let mut new = self.clone();
116 new.0.push('\n');
117 new.0.push_str(more);
118 new
119 }
120}
121
122impl From<String> for Multiline {
123 fn from(value: String) -> Self {
124 Self(value)
125 }
126}
127
128impl<'a> From<&'a str> for Multiline {
129 fn from(value: &'a str) -> Self {
130 Self(value.to_string())
131 }
132}
133
134impl std::fmt::Display for Multiline {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 write!(f, "\"\"\"{}\"\"\"", self.0)
137 }
138}
139
/// Append-only text builder used by `Modelfile::render` to emit instruction lines.
#[derive(Clone, Debug, Default)]
struct Renderer {
    /// Text accumulated so far; handed back by `finalize`.
    builder: String,
}
144
145impl Renderer {
146 fn push_raw(&mut self, s: &str) {
147 self.builder.push_str(s);
148 }
149
150 fn push(&mut self, name: &'static str, s: &str) {
151 self.builder.push_str(name);
152 self.builder.push(' ');
153 self.builder.push_str(s);
154 self.builder.push('\n');
155 }
156
157 fn newline(&mut self) {
158 self.builder.push('\n');
159 }
160
161 fn push_opt<T: ToString>(&mut self, name: &'static str, opt: Option<&T>) {
162 if let Some(t) = opt {
163 self.builder.push_str(name);
164 self.builder.push(' ');
165 self.builder.push_str(&t.to_string());
166 self.builder.push('\n');
167 self.builder.push('\n');
168 }
169 }
170
171 fn push_vec<T: ToString>(&mut self, name: &'static str, vec: impl AsRef<[T]>) {
172 let vec = vec.as_ref();
173 if vec.is_empty() {
174 tracing::debug!("no items passed for {}", name);
175 return;
176 }
177
178 for t in vec {
179 self.builder.push_str(name);
180 self.builder.push(' ');
181 self.builder.push_str(&t.to_string());
182 self.builder.push('\n')
183 }
184
185 self.builder.push('\n')
186 }
187
188 fn finalize(self) -> String {
189 self.builder
190 }
191}
192
193impl FromStr for Modelfile {
194 type Err = ModelfileError;
195
196 fn from_str(input: &str) -> Result<Self, Self::Err> {
197 let instructions: Vec<Instruction> = instructions(input)
198 .map_err(|error| ModelfileError::Parse(error.to_string()))
199 .and_then(|(rest, instructions)| {
200 if rest.is_empty() {
201 Ok(instructions)
202 } else {
203 Err(ModelfileError::Parse(
204 "parser did not consume all input".to_string(),
205 ))
206 }
207 })?;
208
209 instructions.try_into()
210 }
211}
212
213impl TryFrom<Vec<Instruction>> for Modelfile {
214 type Error = ModelfileError;
215
216 fn try_from(instructions: Vec<Instruction>) -> Result<Self, Self::Error> {
217 let builder = ModelfileBuilder::try_from(instructions)?;
218
219 builder.build()
220 }
221}
222
/// A single Modelfile instruction.
///
/// `EnumDiscriminants` also generates [`InstructionName`], a payload-free enum naming each
/// variant. It is used, for example, to inspect instruction order without the payloads.
#[derive(
    Debug,
    Clone,
    From,
    Serialize,
    Deserialize,
    strum::Display,
    AsRefStr,
    IntoStaticStr,
    EnumDiscriminants,
)]
#[strum_discriminants(name(InstructionName))]
#[strum_discriminants(derive(IntoStaticStr, strum::Display, AsRefStr))]
#[serde(rename_all = "snake_case")]
pub enum Instruction {
    /// Variant without payload.
    /// NOTE(review): presumably emitted by the parser for input that yields no
    /// instruction (comments, blank lines); confirm against `parser`.
    Skip,
    /// `FROM` instruction.
    From(BaseModel),
    /// `PARAMETER` instruction.
    Parameter(Parameter),
    /// `TEMPLATE` instruction.
    Template(Template),
    /// `SYSTEM` instruction.
    System(SystemMessage),
    /// `ADAPTER` instruction (also constructible from a [`TensorFile`]).
    Adapter(Adapter),
    /// `LICENSE` instruction.
    License(License),
    /// `MESSAGE` instruction.
    Message(Message),
}
265
266impl From<TensorFile> for Instruction {
267 fn from(value: TensorFile) -> Self {
268 Instruction::Adapter(value.into())
269 }
270}
271
/// Path to a tensor/weights file, tagged with the file format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TensorFile {
    /// A GGUF file.
    Gguf(PathBuf),
    /// A Safetensors file.
    Safetensor(PathBuf),
}
279
280impl AsRef<Path> for TensorFile {
281 fn as_ref(&self) -> &Path {
282 match self {
283 TensorFile::Gguf(path_buf) => path_buf.as_ref(),
284 TensorFile::Safetensor(path_buf) => path_buf.as_ref(),
285 }
286 }
287}
288
289impl Display for TensorFile {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.write_str(&self.as_ref().display().to_string())
292 }
293}
294
/// The value of a single `PARAMETER` instruction.
///
/// `Display` (driven by the `#[strum(to_string = ...)]` attributes) renders
/// `<name> <value>`, e.g. `stop <eos>`. `Modelfile::render` prefixes each rendered
/// parameter with the `PARAMETER` keyword. `EnumDiscriminants` generates the payload-free
/// [`ParameterName`] enum, whose variants convert to and parse from their snake_case names.
#[derive(Debug, Clone, EnumDiscriminants, strum::Display, Serialize, Deserialize, PartialEq)]
#[strum_discriminants(name(ParameterName))]
#[strum_discriminants(derive(
    EnumIter,
    Hash,
    PartialOrd,
    Ord,
    IntoStaticStr,
    EnumString,
    VariantArray,
    Serialize,
    Deserialize
))]
#[strum_discriminants(strum(serialize_all = "snake_case"))]
pub enum Parameter {
    #[strum(to_string = "mirostat {0}")]
    Mirostat(usize),
    #[strum(to_string = "mirostat_eta {0}")]
    MirostatEta(f32),
    #[strum(to_string = "mirostat_tau {0}")]
    MirostatTau(f32),
    #[strum(to_string = "num_ctx {0}")]
    NumCtx(usize),
    /// NOTE(review): Ollama's Modelfile docs list `-1` (meaning "use `num_ctx`") as a
    /// valid `repeat_last_n`. `usize` cannot represent it; confirm this is intended.
    #[strum(to_string = "repeat_last_n {0}")]
    RepeatLastN(usize),
    #[strum(to_string = "repeat_penalty {0}")]
    RepeatPenalty(f32),
    #[strum(to_string = "temperature {0}")]
    Temperature(f32),
    #[strum(to_string = "seed {0}")]
    Seed(usize),
    /// Stop sequence. The text is rendered as-is, without added quoting.
    #[strum(to_string = "stop {0}")]
    Stop(String),
    #[strum(to_string = "tfs_z {0}")]
    TfsZ(f32),
    /// NOTE(review): Ollama documents negative sentinels for `num_predict` (`-1` = no
    /// limit, and in some versions `-2` = fill context). `usize` cannot represent them;
    /// confirm whether Modelfiles using these values must be supported.
    #[strum(to_string = "num_predict {0}")]
    NumPredict(usize),
    #[strum(to_string = "top_k {0}")]
    TopK(usize),
    #[strum(to_string = "top_p {0}")]
    TopP(f32),
    #[strum(to_string = "min_p {0}")]
    MinP(f32),
}
396
#[cfg(test)]
mod tests {
    use insta::{assert_debug_snapshot, assert_snapshot};
    use test_data::{load_modelfiles, TestData, TEST_GOOD_DATA_DIR};

    use super::{test_data::TEST_BAD_DATA_DIR, *};

    /// Every fixture in the "good" directory must parse into a `Modelfile`.
    #[test]
    fn modelfiles_are_parsed() {
        let modelfiles: Vec<TestData> = load_modelfiles(TEST_GOOD_DATA_DIR);

        for TestData {
            path,
            contents: case,
        } in modelfiles
        {
            // Print the path first so a failing fixture is identifiable in the output.
            dbg!(&path);
            let modelfile: Modelfile = case
                .parse::<Modelfile>()
                .expect("should be able to parse Modelfile");

            dbg!(modelfile);
        }
    }

    /// Every fixture in the "bad" directory must be rejected. All of them currently fail
    /// with the same `error building Modelfile from parts` message.
    #[test]
    fn bad_modelfiles_are_not_parsed() {
        let modelfiles: Vec<TestData> = load_modelfiles(TEST_BAD_DATA_DIR);

        for TestData {
            path,
            contents: case,
        } in modelfiles
        {
            dbg!(&path);
            let result = case
                .parse::<Modelfile>()
                .expect_err("should not be able to parse bad Modelfiles");

            insta::assert_snapshot!(result.to_string(), @"error building Modelfile from parts");
        }
    }

    /// Every good fixture must serialize to TOML through serde. The TOML text itself is
    /// not checked.
    #[test]
    fn modelfiles_can_be_rendered_as_toml() {
        let modelfiles: Vec<TestData> = load_modelfiles(TEST_GOOD_DATA_DIR);

        for TestData {
            path,
            contents: case,
        } in modelfiles
        {
            dbg!(&path);
            let modelfile: Modelfile = case
                .parse::<Modelfile>()
                .expect("should be able to parse Modelfile");

            let _rendered =
                toml::to_string(&modelfile).expect("should be able to render Modelfiles as TOML");

            dbg!(modelfile);
        }
    }

    /// Renders the `llama3.2.latest.Modelfile` fixture, checks that the rendered text
    /// parses again, and snapshots the rendered text.
    #[test]
    fn snapshot_render() {
        let modelfile: Modelfile = load_modelfiles(TEST_GOOD_DATA_DIR)
            .into_iter()
            .find(|TestData { path, contents: _ }| {
                path.file_name()
                    .expect("test data should have a valid filename")
                    .to_str()
                    .expect("should be able to convert OsStr to str")
                    == "llama3.2.latest.Modelfile"
            })
            .expect("should have at least one test case")
            .contents
            .parse()
            .expect("should be able to parse test data");

        let render = modelfile.render();

        dbg!(&render);

        // Round-trip check: whatever `render` emits must be accepted by the parser.
        let _modelfile: Modelfile = render
            .parse()
            .expect("should be able to parse rendered Modelfile");

        assert_snapshot!(render);
    }

    /// `Parameter`'s `Display` form is `<name> <value>`, without the `PARAMETER` keyword.
    #[test]
    fn snapshot_parameters() {
        let param = Parameter::Stop("<eos>".into());

        assert_snapshot!(param, @"stop <eos>");
    }

    /// A Modelfile built only from a base model yields exactly one instruction.
    #[test]
    fn minimal_modelfile_produces_iterator() {
        let base_model = BaseModel::from("llama8.2");

        let modelfile = Modelfile::from(base_model);

        let instructions: Vec<Instruction> = modelfile.instructions().collect();

        assert!(!instructions.is_empty());
        assert_eq!(instructions.len(), 1);
    }

    /// Pins the instruction order for the `llama3.1.latest.Modelfile` fixture. Also checks
    /// that collecting those instructions back into a `Modelfile` reproduces an equal value.
    #[test]
    fn modelfile_instructions_snapshot() {
        let test_data: Vec<TestData> = load_modelfiles(TEST_GOOD_DATA_DIR);

        let test_data = test_data
            .into_iter()
            .find(|data| {
                data.path
                    .file_name()
                    .expect("unable to read file name")
                    .to_str()
                    .expect("unable to convert filename to string")
                    == "llama3.1.latest.Modelfile"
            })
            .expect("could not load test data");

        let modelfile: Modelfile = test_data
            .contents
            .parse()
            .expect("unable to parse test data");

        let instructions: Vec<Instruction> = modelfile.clone().instructions().collect();

        // Compare by variant name only; the payloads are covered by the equality check below.
        let instruction_names: Vec<InstructionName> = instructions.iter().map(Into::into).collect();

        assert_debug_snapshot!(instruction_names, @r"
        [
            From,
            Parameter,
            Parameter,
            Parameter,
            Template,
            License,
        ]
        ");

        let result: Modelfile = instructions
            .try_into()
            .expect("unable to go from instructions to Modelfile");

        assert_eq!(result, modelfile);
    }
}
549}