model_gateway_rs/sdk/ollama.rs

use async_trait::async_trait;
use toolcraft_request::{ByteStream, Request};

use crate::{
    error::Result,
    model::{
        llm::LlmInput,
        ollama::{OllamaChatOptions, OllamaChatRequest, OllamaChatResponse},
    },
    sdk::ModelSDK,
};

/// Ollama chat client bound to a single model and base URL.
pub struct OllamaSdk {
    request: Request,
    model: String,
}

impl OllamaSdk {
    /// Creates a client that sends JSON requests to `base_url`
    /// (for a local Ollama server, typically http://localhost:11434/api/).
    pub fn new(base_url: &str, model: &str) -> Result<Self> {
        let mut request = Request::new()?;
        request.set_base_url(base_url)?;
        request.set_default_headers(vec![("Content-Type", "application/json".to_string())])?;
        Ok(Self {
            request,
            model: model.to_string(),
        })
    }
}

#[async_trait]
impl ModelSDK for OllamaSdk {
    type Input = LlmInput;
    type Output = OllamaChatResponse;

    /// Sends one non-streaming chat request and deserializes the
    /// complete response body.
    async fn chat_once(&self, input: Self::Input) -> Result<Self::Output> {
        // Map the generic LLM input onto Ollama-specific options.
        let options = OllamaChatOptions {
            num_predict: input.max_tokens,
            temperature: None,
        };
        let body = OllamaChatRequest {
            model: self.model.clone(),
            messages: input.messages,
            stream: Some(false),
            options: Some(options),
        };
        let payload = serde_json::to_value(body)?;
        let response = self.request.post("chat", &payload, None).await?;
        let json: OllamaChatResponse = response.json().await?;
        Ok(json)
    }

    /// Sends a streaming chat request; returns the raw byte stream so the
    /// caller can decode Ollama's line-delimited JSON chunks itself.
    async fn chat_stream(&self, input: Self::Input) -> Result<ByteStream> {
        let options = OllamaChatOptions {
            num_predict: input.max_tokens,
            temperature: None,
        };
        let body = OllamaChatRequest {
            model: self.model.clone(),
            messages: input.messages,
            stream: Some(true),
            options: Some(options),
        };
        let payload = serde_json::to_value(body)?;
        let stream = self.request.post_stream("chat", &payload, None).await?;
        Ok(stream)
    }
}
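
A minimal usage sketch for the client above, wrapped in a hypothetical `demo` function and assuming a Tokio runtime. The crate/module paths, the "llama3.2" model name, `LlmInput: Clone`, a `Debug` impl on the response and stream items, conversion of the crate's error type into `anyhow::Error`, and `ByteStream` implementing `futures::Stream` are all assumptions beyond what this file shows.

use futures::StreamExt;
use model_gateway_rs::{model::llm::LlmInput, sdk::{ollama::OllamaSdk, ModelSDK}};

// Illustrative only; names and trait bounds here are assumptions.
async fn demo(input: LlmInput) -> anyhow::Result<()> {
    // Default local Ollama endpoint; relative paths such as "chat"
    // resolve against this base, i.e. http://localhost:11434/api/chat.
    let sdk = OllamaSdk::new("http://localhost:11434/api/", "llama3.2")?;

    // One-shot call: resolves once the complete response has arrived.
    let reply = sdk.chat_once(input.clone()).await?;
    println!("{reply:?}");

    // Streaming call: Ollama emits one JSON object per line; this just
    // prints the raw chunks and leaves decoding to the caller.
    let mut stream = sdk.chat_stream(input).await?;
    while let Some(chunk) = stream.next().await {
        println!("{chunk:?}");
    }
    Ok(())
}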