// model_gateway_rs/sdk/ollama.rs

1use async_trait::async_trait;
2use toolcraft_request::{ByteStream, Request};
3
4use crate::{
5    error::Result,
6    model::{
7        llm::LlmInput,
8        ollama::{OllamaChatOptions, OllamaChatRequest, OllamaChatResponse},
9    },
10    sdk::ModelSDK,
11};
12
13/// ChatCompletion client using your wrapped Request.
14pub struct OllamaSdk {
15    request: Request,
16    model: String,
17}
18
19impl OllamaSdk {
20    pub fn new(base_url: &str, model: &str) -> Result<Self> {
21        let mut request = Request::new()?;
22        request.set_base_url(base_url)?;
23        request.set_default_headers(vec![("Content-Type", "application/json".to_string())])?;
24        Ok(Self {
25            request,
26            model: model.to_string(),
27        })
28    }
29}
30
31#[async_trait]
32impl ModelSDK for OllamaSdk {
33    type Input = LlmInput;
34    type Output = OllamaChatResponse;
35
36    /// Send a chat request and get full response.
37    async fn chat_once(&self, input: Self::Input) -> Result<Self::Output> {
38        let options = OllamaChatOptions {
39            num_predict: input.max_tokens,
40            temperature: None,
41        };
42        let body = OllamaChatRequest {
43            model: self.model.clone(),
44            messages: input.messages,
45            stream: Some(false),
46            options: Some(options),
47        };
48        let payload = serde_json::to_value(body)?;
49        let response = self.request.post("chat", &payload, None).await?;
50        let json: OllamaChatResponse = response.json().await?;
51        Ok(json)
52    }
53
54    /// Send a chat request and get response stream (SSE).
55    async fn chat_stream(&self, input: Self::Input) -> Result<ByteStream> {
56        let options = OllamaChatOptions {
57            num_predict: input.max_tokens,
58            temperature: None,
59        };
60        let body = OllamaChatRequest {
61            model: self.model.clone(),
62            messages: input.messages,
63            stream: Some(true),
64            options: Some(options),
65        };
66        let payload = serde_json::to_value(body)?;
67        let r = self.request.post_stream("chat", &payload, None).await?;
68        Ok(r)
69    }
70}