use async_trait::async_trait;
use base64::Engine as _;
use futures::stream::{BoxStream, StreamExt};
use ollama_rs::generation::chat::request::ChatMessageRequest;
use ollama_rs::generation::chat::{ChatMessage, MessageRole};
use ollama_rs::generation::images::Image;
use ollama_rs::Ollama;

use crate::error::{AgentError, Result};
use crate::models::LLM;
use crate::types::{File, GenerationChunk, GenerationResponse, Message, Role};
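/// `LLM` backend that talks to an Ollama server through the `ollama-rs` client.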
pub struct OllamaLLM {
client: Ollama,
model: String,
}
impl OllamaLLM {
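    /// Creates a client for the default local Ollama endpoint (http://127.0.0.1:11434).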
pub fn new(model: impl Into<String>) -> Self {
Self {
client: Ollama::default(),
model: model.into(),
}
}
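    /// Creates a client for an Ollama server running on a custom host and port.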
pub fn with_host(host: impl Into<String>, port: u16, model: impl Into<String>) -> Self {
Self {
client: Ollama::new(host.into(), port),
model: model.into(),
}
}
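    /// Maps the crate's `Role` onto `ollama-rs`'s `MessageRole`.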
    fn convert_role(role: &Role) -> MessageRole {
        match role {
            Role::System => MessageRole::System,
            Role::User => MessageRole::User,
            Role::Assistant => MessageRole::Assistant,
            // Tool output has no direct counterpart here, so it is forwarded
            // as a user message.
            Role::Tool => MessageRole::User,
        }
    }
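    /// Converts an internal `Message` into an `ollama-rs` `ChatMessage`.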
    fn convert_message(&self, msg: &Message) -> ChatMessage {
        // Using the constructor rather than a struct literal keeps this
        // compatible across ollama-rs versions with differing field sets.
        ChatMessage::new(Self::convert_role(&msg.role), msg.content.clone())
    }
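    /// Converts the message history and attaches any image files to the last message.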
fn prepare_messages(&self, messages: Vec<Message>, files: Option<Vec<File>>) -> Vec<ChatMessage> {
let mut chat_messages: Vec<ChatMessage> =
messages.iter().map(|m| self.convert_message(m)).collect();
        if let Some(files) = files {
            // Only image attachments are supported; non-image files are dropped.
            // NOTE: assumes `Image::from_base64`, which wraps an already
            // base64-encoded payload.
            let images: Vec<Image> = files
                .into_iter()
                .filter(|f| f.mime_type.starts_with("image/"))
                .map(|f| {
                    let encoded = base64::engine::general_purpose::STANDARD.encode(&f.data);
                    Image::from_base64(&encoded)
                })
                .collect();
            // Ollama expects images to ride along with a chat message, so they
            // are attached to the most recent one (normally the user prompt).
            if !images.is_empty() {
                if let Some(last_msg) = chat_messages.last_mut() {
                    last_msg.images = Some(images);
                }
            }
        }
chat_messages
}
}
#[async_trait]
impl LLM for OllamaLLM {
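    /// Sends the full conversation to Ollama and returns the complete response.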
async fn generate(
&self,
messages: Vec<Message>,
files: Option<Vec<File>>,
) -> Result<GenerationResponse> {
let chat_messages = self.prepare_messages(messages, files);
let request = ChatMessageRequest::new(self.model.clone(), chat_messages);
let response = self
.client
.send_chat_messages(request)
.await
.map_err(|e| AgentError::ModelError(format!("Ollama error: {}", e)))?;
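        // NOTE: assumes an ollama-rs version where `message` is a plain
        // `ChatMessage` rather than an `Option<ChatMessage>`.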
Ok(GenerationResponse {
content: response.message.content,
metadata: None,
})
}
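    /// Streams the response incrementally as `GenerationChunk`s.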
async fn stream_generate(
&self,
messages: Vec<Message>,
files: Option<Vec<File>>,
) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
let chat_messages = self.prepare_messages(messages, files);
let request = ChatMessageRequest::new(self.model.clone(), chat_messages);
let stream = self
.client
.send_chat_messages_stream(request)
.await
.map_err(|e| AgentError::ModelError(format!("Ollama error: {}", e)))?;
        // Each chunk carries the next piece of the assistant message; as in
        // `generate`, `message` is assumed not to be an `Option`. The stream's
        // error type exposes little detail in some ollama-rs versions, so it is
        // surfaced via Debug formatting.
        let mapped = stream.map(|chunk_res| match chunk_res {
            Ok(chunk) => Ok(GenerationChunk {
                content: chunk.message.content,
                metadata: None,
            }),
            Err(e) => Err(AgentError::ModelError(format!(
                "Ollama stream error: {:?}",
                e
            ))),
        });
Ok(Box::pin(mapped))
}
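    /// Returns the model identifier this client was configured with.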
fn model_name(&self) -> &str {
&self.model
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
    #[ignore] // requires a running Ollama server with the model pulled
    async fn test_ollama_generate() {
let llm = OllamaLLM::new("llama2");
let messages = vec![Message {
role: Role::User,
content: "Say 'test' and nothing else.".to_string(),
metadata: None,
}];
let response = llm.generate(messages, None).await;
assert!(response.is_ok());
}
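
    // Offline sanity checks for the conversion helpers; these do not need a
    // running Ollama server.
    #[test]
    fn test_convert_role_maps_tool_to_user() {
        assert!(matches!(
            OllamaLLM::convert_role(&Role::Tool),
            MessageRole::User
        ));
    }

    #[test]
    fn test_model_name_matches_constructor_argument() {
        let llm = OllamaLLM::new("llama2");
        assert_eq!(llm.model_name(), "llama2");
    }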
}