// ollama_rs_mangle_fork/models/create.rs

use serde::{Deserialize, Serialize};

use crate::Ollama;

#[cfg(feature = "stream")]
/// Pinned, boxed stream that yields one [`CreateModelStatus`] (or an error) per
/// update reported while a model is being created.
pub type CreateModelStatusStream = std::pin::Pin<
    Box<dyn tokio_stream::Stream<Item = crate::error::Result<CreateModelStatus>>>,
>;

10impl Ollama {
11    #[cfg(feature = "stream")]
12    /// Create a model with streaming, meaning that each new status will be streamed.
13    pub async fn create_model_stream(
14        &self,
15        model_name: String,
16        path: String,
17    ) -> crate::error::Result<CreateModelStatusStream> {
18        use tokio_stream::StreamExt;
19
20        use crate::error::OllamaError;
21
22        let request = CreateModelRequest {
23            model_name,
24            path,
25            stream: true,
26        };
27
28        let uri = format!("{}/api/create", self.uri());
29        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
30        let res = self
31            .reqwest_client
32            .post(uri)
33            .body(serialized)
34            .send()
35            .await
36            .map_err(|e| e.to_string())?;
37
38        if !res.status().is_success() {
39            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
40        }
41
42        let stream = Box::new(res.bytes_stream().map(|res| match res {
43            Ok(bytes) => {
44                let res = serde_json::from_slice::<CreateModelStatus>(&bytes);
45                match res {
46                    Ok(res) => Ok(res),
47                    Err(e) => {
48                        let err = serde_json::from_slice::<crate::error::OllamaError>(&bytes);
49                        match err {
50                            Ok(err) => Err(err),
51                            Err(_) => Err(OllamaError::from(format!(
52                                "Failed to deserialize response: {}",
53                                e
54                            ))),
55                        }
56                    }
57                }
58            }
59            Err(e) => Err(OllamaError::from(format!("Failed to read response: {}", e))),
60        }));
61
62        Ok(std::pin::Pin::from(stream))
63    }
64
65    /// Create a model with a single response, only the final status will be returned.
66    pub async fn create_model(
67        &self,
68        model_name: String,
69        path: String,
70    ) -> crate::error::Result<CreateModelStatus> {
71        let request = CreateModelRequest {
72            model_name,
73            path,
74            stream: false,
75        };
76
77        let uri = format!("{}/api/create", self.uri());
78        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
79        let res = self
80            .reqwest_client
81            .post(uri)
82            .body(serialized)
83            .send()
84            .await
85            .map_err(|e| e.to_string())?;
86
87        if !res.status().is_success() {
88            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
89        }
90
91        let res = res.bytes().await.map_err(|e| e.to_string())?;
92        let res = serde_json::from_slice::<CreateModelStatus>(&res).map_err(|e| e.to_string())?;
93
94        Ok(res)
95    }
96}
97
98/// A create model request to Ollama.
99#[derive(Serialize)]
100struct CreateModelRequest {
101    #[serde(rename = "name")]
102    model_name: String,
103    path: String,
104    stream: bool,
105}
106
107/// A create model status response from Ollama.
108#[derive(Deserialize, Debug)]
109pub struct CreateModelStatus {
110    #[serde(rename = "status")]
111    pub message: String,
112}