ollama_rs/models/
create.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{error::OllamaError, generation::chat::ChatMessage, Ollama};
4
5use super::ModelOptions;
6
/// Boxed, pinned, `Send` stream of [`CreateModelStatus`] updates.
///
/// Each item is either a decoded status or the error encountered while
/// reading or decoding that part of the server's response.
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
pub type CreateModelStatusStream = std::pin::Pin<
    Box<
        dyn tokio_stream::Stream<Item = crate::error::Result<CreateModelStatus>>
            + Send,
    >,
>;
13
14impl Ollama {
15    #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
16    #[cfg(feature = "stream")]
17    /// Create a model with streaming, meaning that each new status will be streamed.
18    pub async fn create_model_stream(
19        &self,
20        mut request: CreateModelRequest,
21    ) -> crate::error::Result<CreateModelStatusStream> {
22        use tokio_stream::StreamExt;
23
24        use crate::error::OllamaError;
25
26        request.stream = true;
27
28        let url = format!("{}api/create", self.url_str());
29        let builder = self.reqwest_client.post(url);
30
31        #[cfg(feature = "headers")]
32        let builder = builder.headers(self.request_headers.clone());
33
34        let res = builder.json(&request).send().await?;
35
36        if !res.status().is_success() {
37            return Err(OllamaError::Other(res.text().await?));
38        }
39
40        let stream = Box::new(res.bytes_stream().map(|res| match res {
41            Ok(bytes) => {
42                let res = serde_json::from_slice::<CreateModelStatus>(&bytes);
43                match res {
44                    Ok(res) => Ok(res),
45                    Err(e) => {
46                        let err =
47                            serde_json::from_slice::<crate::error::InternalOllamaError>(&bytes);
48                        match err {
49                            Ok(err) => Err(OllamaError::InternalError(err)),
50                            Err(_) => Err(OllamaError::from(e)),
51                        }
52                    }
53                }
54            }
55            Err(e) => Err(OllamaError::Other(format!(
56                "Failed to read response: {}",
57                e
58            ))),
59        }));
60
61        Ok(std::pin::Pin::from(stream))
62    }
63
64    /// Create a model with a single response, only the final status will be returned.
65    pub async fn create_model(
66        &self,
67        request: CreateModelRequest,
68    ) -> crate::error::Result<CreateModelStatus> {
69        let url = format!("{}api/create", self.url_str());
70        let builder = self.reqwest_client.post(url);
71
72        #[cfg(feature = "headers")]
73        let builder = builder.headers(self.request_headers.clone());
74
75        let res = builder.json(&request).send().await?;
76
77        if !res.status().is_success() {
78            return Err(OllamaError::Other(res.text().await?));
79        }
80
81        let res = res.bytes().await?;
82        let res = serde_json::from_slice::<CreateModelStatus>(&res)?;
83
84        Ok(res)
85    }
86}
87
88#[derive(Serialize)]
89pub enum QuantizationType {
90    #[serde(rename = "q2_K")]
91    Q2K,
92    #[serde(rename = "q3_K_L")]
93    Q3KL,
94    #[serde(rename = "q3_K_M")]
95    Q3KM,
96    #[serde(rename = "q3_K_S")]
97    Q3KS,
98    #[serde(rename = "q4_0")]
99    Q40,
100    #[serde(rename = "q4_1")]
101    Q41,
102    #[serde(rename = "q4_K_M")]
103    Q4KM,
104    #[serde(rename = "q4_K_S")]
105    Q4KS,
106    #[serde(rename = "q5_0")]
107    Q50,
108    #[serde(rename = "q5_1")]
109    Q51,
110    #[serde(rename = "q5_K_M")]
111    Q5KM,
112    #[serde(rename = "q5_K_S")]
113    Q5KS,
114    #[serde(rename = "q6_K")]
115    Q6K,
116    #[serde(rename = "q8_0")]
117    Q80,
118}
119
120/// A create model request to Ollama.
121#[derive(Serialize)]
122pub struct CreateModelRequest {
123    /// Name of the model to create
124    #[serde(rename = "model")]
125    model_name: String,
126    /// Name of an existing model to create the new model from
127    #[serde(rename = "from")]
128    from_model: Option<String>,
129    /// A dictionary of file names to SHA256 digests of blobs to create the model from
130    #[serde(skip_serializing_if = "Option::is_none")]
131    files: Option<std::collections::HashMap<String, String>>,
132    /// A dictionary of file names to SHA256 digests of blobs for LORA adapters
133    #[serde(skip_serializing_if = "Option::is_none")]
134    adapters: Option<std::collections::HashMap<String, String>>,
135    /// The prompt template for the model
136    #[serde(skip_serializing_if = "Option::is_none")]
137    template: Option<String>,
138    /// A string or list of strings containing the license or licenses for the model
139    #[serde(skip_serializing_if = "Option::is_none")]
140    license: Option<Vec<String>>,
141    /// A string containing the system prompt for the model
142    #[serde(skip_serializing_if = "Option::is_none")]
143    system: Option<String>,
144    /// A dictionary of parameters for the model
145    #[serde(skip_serializing_if = "Option::is_none")]
146    parameters: Option<ModelOptions>,
147    /// A list of message objects used to create a conversation
148    #[serde(skip_serializing_if = "Option::is_none")]
149    messages: Option<Vec<ChatMessage>>,
150    stream: bool,
151    /// Quantize a non-quantized model
152    #[serde(skip_serializing_if = "Option::is_none")]
153    quantize: Option<QuantizationType>,
154}
155
156impl CreateModelRequest {
157    pub fn new(model_name: String) -> Self {
158        Self {
159            model_name,
160            from_model: None,
161            files: None,
162            adapters: None,
163            template: None,
164            license: None,
165            system: None,
166            parameters: None,
167            messages: None,
168            stream: false,
169            quantize: None,
170        }
171    }
172
173    pub fn from_model(mut self, from_model: String) -> Self {
174        self.from_model = Some(from_model);
175        self
176    }
177
178    pub fn files(mut self, files: std::collections::HashMap<String, String>) -> Self {
179        self.files = Some(files);
180        self
181    }
182
183    pub fn adapters(mut self, adapters: std::collections::HashMap<String, String>) -> Self {
184        self.adapters = Some(adapters);
185        self
186    }
187
188    pub fn template(mut self, template: String) -> Self {
189        self.template = Some(template);
190        self
191    }
192
193    pub fn license(mut self, license: String) -> Self {
194        self.license = Some(vec![license]);
195        self
196    }
197
198    pub fn licenses(mut self, licenses: Vec<String>) -> Self {
199        self.license = Some(licenses);
200        self
201    }
202
203    pub fn system(mut self, system: String) -> Self {
204        self.system = Some(system);
205        self
206    }
207
208    pub fn parameters(mut self, parameters: ModelOptions) -> Self {
209        self.parameters = Some(parameters);
210        self
211    }
212
213    pub fn messages(mut self, messages: Vec<ChatMessage>) -> Self {
214        self.messages = Some(messages);
215        self
216    }
217
218    pub fn quantize(mut self, quantize: QuantizationType) -> Self {
219        self.quantize = Some(quantize);
220        self
221    }
222}
223
224/// A create model status response from Ollama.
225#[derive(Deserialize, Debug)]
226pub struct CreateModelStatus {
227    #[serde(rename = "status")]
228    pub message: String,
229}