1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use serde::{Deserialize, Serialize};

use crate::Ollama;

/// A pinned, boxed stream of [`CreateModelStatus`] results, as returned by
/// `Ollama::create_model_stream`.
#[cfg(feature = "stream")]
pub type CreateModelStatusStream = std::pin::Pin<
    Box<dyn tokio_stream::Stream<Item = crate::error::Result<CreateModelStatus>> + Send + 'static>,
>;

impl Ollama {
    #[cfg(feature = "stream")]
    /// Create a model with streaming, meaning that each new status will be streamed.
    pub async fn create_model_stream(
        &self,
        mut request: CreateModelRequest,
    ) -> crate::error::Result<CreateModelStatusStream> {
        use tokio_stream::StreamExt;

        use crate::error::OllamaError;

        request.stream = true;

        let url = format!("{}api/create", self.url_str());
        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
        let res = self
            .reqwest_client
            .post(url)
            .body(serialized)
            .send()
            .await
            .map_err(|e| e.to_string())?;

        if !res.status().is_success() {
            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
        }

        let stream = Box::new(res.bytes_stream().map(|res| match res {
            Ok(bytes) => {
                let res = serde_json::from_slice::<CreateModelStatus>(&bytes);
                match res {
                    Ok(res) => Ok(res),
                    Err(e) => {
                        let err = serde_json::from_slice::<crate::error::OllamaError>(&bytes);
                        match err {
                            Ok(err) => Err(err),
                            Err(_) => Err(OllamaError::from(format!(
                                "Failed to deserialize response: {}",
                                e
                            ))),
                        }
                    }
                }
            }
            Err(e) => Err(OllamaError::from(format!("Failed to read response: {}", e))),
        }));

        Ok(std::pin::Pin::from(stream))
    }

    /// Create a model with a single response, only the final status will be returned.
    pub async fn create_model(
        &self,
        request: CreateModelRequest,
    ) -> crate::error::Result<CreateModelStatus> {
        let url = format!("{}api/create", self.url_str());
        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
        let res = self
            .reqwest_client
            .post(url)
            .body(serialized)
            .send()
            .await
            .map_err(|e| e.to_string())?;

        if !res.status().is_success() {
            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
        }

        let res = res.bytes().await.map_err(|e| e.to_string())?;
        let res = serde_json::from_slice::<CreateModelStatus>(&res).map_err(|e| e.to_string())?;

        Ok(res)
    }
}

/// A create model request to Ollama.
///
/// Build one with [`CreateModelRequest::path`] or [`CreateModelRequest::modelfile`];
/// each constructor sets exactly one of the two Modelfile sources.
#[derive(Debug, Clone, Serialize)]
pub struct CreateModelRequest {
    /// Name of the model to create; sent as the `name` JSON field.
    #[serde(rename = "name")]
    model_name: String,
    /// Path to a Modelfile; left out of the JSON body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    path: Option<String>,
    /// Inline Modelfile contents; left out of the JSON body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    modelfile: Option<String>,
    /// Whether the server should stream status updates. Always serialized so the
    /// choice is explicit in every request.
    stream: bool,
}

impl CreateModelRequest {
    /// Build a request whose Modelfile is read by the server from `path`.
    ///
    /// The request starts out non-streaming (`stream` is `false`).
    pub fn path(model_name: String, path: String) -> Self {
        Self::from_source(model_name, Some(path), None)
    }

    /// Build a request whose Modelfile text is supplied inline as `modelfile`.
    ///
    /// The request starts out non-streaming (`stream` is `false`).
    pub fn modelfile(model_name: String, modelfile: String) -> Self {
        Self::from_source(model_name, None, Some(modelfile))
    }

    /// Shared constructor: stores the given Modelfile source(s) and leaves
    /// streaming disabled.
    fn from_source(model_name: String, path: Option<String>, modelfile: Option<String>) -> Self {
        Self {
            model_name,
            path,
            modelfile,
            stream: false,
        }
    }
}

/// A create model status response from Ollama.
///
/// `create_model` returns only the final status of a request, while
/// `create_model_stream` yields each status as the server reports it.
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct CreateModelStatus {
    /// Status text reported by the server (the `status` JSON field).
    #[serde(rename = "status")]
    pub message: String,
}