langchain_rust/llm/ollama/openai.rs

use async_openai::config::Config;
use reqwest::header::HeaderMap;
use secrecy::Secret;
use serde::Deserialize;

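/// Default base URL of Ollama's OpenAI-compatible API; Ollama listens on port 11434 by default.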
const OLLAMA_API_BASE: &str = "http://localhost:11434/v1";

/// Ollama offers [OpenAI compatibility](https://ollama.com/blog/openai-compatibility), meaning you can talk to a local Ollama server with any OpenAI client.
///
/// This struct implements the `Config` trait from `async-openai`, providing the OpenAI client configuration needed to use Ollama.
///
/// ## Example
///
/// ```ignore
/// let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama3.2");
/// let response = ollama.invoke("Say hello!").await.unwrap();
/// ```
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OllamaConfig {
    api_base: String,
    api_key: Secret<String>,
}

impl OllamaConfig {
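    /// Creates a config pointing at the default local endpoint, `http://localhost:11434/v1`.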
    pub fn new() -> Self {
        Self::default()
    }

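    /// Overrides the API key sent with requests, e.g. for a proxy in front of
    /// Ollama that enforces authentication.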
    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
        self.api_key = Secret::from(api_key.into());
        self
    }

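    /// Overrides the base URL, e.g. when Ollama runs on a different host or port.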
    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
        self.api_base = api_base.into();
        self
    }
}

impl Config for OllamaConfig {
    fn api_key(&self) -> &Secret<String> {
        &self.api_key
    }

    fn api_base(&self) -> &str {
        &self.api_base
    }

    fn headers(&self) -> HeaderMap {
        // Ollama requires no authentication headers.
        HeaderMap::default()
    }

    fn query(&self) -> Vec<(&str, &str)> {
        vec![]
    }

    fn url(&self, path: &str) -> String {
        // Join the base URL with the endpoint path, e.g. "/chat/completions".
        format!("{}{}", self.api_base(), path)
    }
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self {
            api_base: OLLAMA_API_BASE.to_string(),
            // "ollama" is a placeholder: OpenAI clients require an API key,
            // but the Ollama server does not check it.
            api_key: Secret::new("ollama".to_string()),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{language_models::llm::LLM, llm::openai::OpenAI, schemas::Message};
    use tokio::io::AsyncWriteExt;
    use tokio_stream::StreamExt;

    // Hits a live Ollama server, so it is ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_ollama_openai() {
        let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama2");
        let response = ollama.invoke("hello").await.unwrap();
        println!("{}", response);
    }

    // Also requires a running server; streams tokens to stdout as they arrive.
    #[tokio::test]
    #[ignore]
    async fn test_ollama_openai_stream() {
        let ollama = OpenAI::new(OllamaConfig::default()).with_model("phi3");

        let message = Message::new_human_message("Why does water boil at 100 degrees?");
        let mut stream = ollama.stream(&[message]).await.unwrap();
        let mut stdout = tokio::io::stdout();
        while let Some(res) = stream.next().await {
            let data = res.unwrap();
            // write_all: plain write may stop after a partial write.
            stdout.write_all(data.content.as_bytes()).await.unwrap();
        }
        stdout.flush().await.unwrap();
    }
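
    // Offline sketch of a builder test: exercises the builder methods and
    // `Config::url` joining without a running Ollama server. The asserted
    // values are arbitrary examples.
    #[test]
    fn test_ollama_config_builder() {
        use secrecy::ExposeSecret;

        let config = OllamaConfig::new()
            .with_api_base("http://remote-host:11434/v1")
            .with_api_key("custom-key");

        assert_eq!(config.api_base(), "http://remote-host:11434/v1");
        assert_eq!(config.api_key().expose_secret(), "custom-key");
        assert_eq!(
            config.url("/chat/completions"),
            "http://remote-host:11434/v1/chat/completions"
        );
    }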
}