use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde_json;
use super::cloud::{build_openai_messages, build_openai_tools, parse_openai_response, OpenAIResponse};
use super::provider::ModelProvider;
use super::types::*;
/// Flavour of local inference server a [`LocalModelProvider`] talks to.
///
/// The variant selects the request path and the short name used in provider naming and
/// error messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalRuntime {
    /// An Ollama server.
    Ollama,
    /// An LM Studio server.
    LmStudio,
    /// A llama.cpp server.
    LlamaCpp,
    /// A local runtime not covered by the named variants.
    Custom,
}
impl LocalRuntime {
    /// Path of the chat-completions endpoint, appended to the provider's base URL.
    ///
    /// Every runtime currently uses the same `/v1/chat/completions` route. The match stays
    /// exhaustive so that adding a variant forces an explicit decision here.
    fn completions_path(&self) -> &'static str {
        match self {
            Self::Ollama | Self::LmStudio | Self::LlamaCpp | Self::Custom => {
                "/v1/chat/completions"
            }
        }
    }

    /// Short identifier for this runtime, used as the provider name and in error messages.
    fn display_name(&self) -> &'static str {
        match self {
            Self::Ollama => "ollama",
            Self::LmStudio => "lm_studio",
            Self::LlamaCpp => "llama_cpp",
            Self::Custom => "custom_local",
        }
    }
}
/// [`ModelProvider`] backed by a locally running inference server.
#[derive(Debug, Clone)]
pub struct LocalModelProvider {
    /// HTTP client reused for every request.
    client: Client,
    /// Server root; the runtime's completions path is appended to it for each request.
    base_url: String,
    /// Model identifier sent in the `model` field of each request.
    model: String,
    /// Which local server flavour this provider talks to.
    runtime: LocalRuntime,
    /// When true, messages use the structured tool-call format and tool definitions are
    /// attached; otherwise tool traffic is flattened into plain text.
    supports_tool_calls: bool,
}
impl LocalModelProvider {
    /// Creates a provider for an arbitrary local server.
    ///
    /// * `base_url` – server root; the runtime's completions path is appended to it.
    /// * `model` – model identifier sent with every request.
    /// * `runtime` – server flavour, used for the endpoint path and in messages.
    /// * `supports_tool_calls` – whether to send structured tool calls and tool definitions.
    pub fn new(
        base_url: String,
        model: String,
        runtime: LocalRuntime,
        supports_tool_calls: bool,
    ) -> Self {
        let client = Client::new();
        Self {
            client,
            base_url,
            model,
            runtime,
            supports_tool_calls,
        }
    }

    /// Shared body of the preset constructors.
    ///
    /// It substitutes `fallback_model` when no model is given and always disables
    /// structured tool calls.
    fn preset(
        base_url: &str,
        model: Option<String>,
        fallback_model: &str,
        runtime: LocalRuntime,
    ) -> Self {
        let chosen_model = match model {
            Some(name) => name,
            None => fallback_model.to_owned(),
        };
        Self::new(base_url.to_owned(), chosen_model, runtime, false)
    }

    /// Provider for `http://localhost:11434` with the Ollama runtime.
    ///
    /// The model defaults to `llama3`; tool calls are off.
    pub fn ollama(model: Option<String>) -> Self {
        Self::preset("http://localhost:11434", model, "llama3", LocalRuntime::Ollama)
    }

    /// Provider for `http://localhost:1234` with the LM Studio runtime.
    ///
    /// The model defaults to `local-model`; tool calls are off.
    pub fn lm_studio(model: Option<String>) -> Self {
        Self::preset("http://localhost:1234", model, "local-model", LocalRuntime::LmStudio)
    }

    /// Provider for `http://localhost:8080` with the llama.cpp runtime.
    ///
    /// The model defaults to `local-model`; tool calls are off.
    pub fn llama_cpp(model: Option<String>) -> Self {
        Self::preset("http://localhost:8080", model, "local-model", LocalRuntime::LlamaCpp)
    }
}
#[async_trait]
impl ModelProvider for LocalModelProvider {
    /// Sends a non-streaming chat-completion request to the local server.
    ///
    /// When tool calls are supported, messages are built in the structured format and any
    /// tool definitions are attached. Otherwise tool traffic is flattened into plain text.
    ///
    /// # Errors
    /// Fails in three cases:
    /// * the server cannot be reached;
    /// * it answers with a non-success status (the response body is included in the error);
    /// * the body cannot be parsed as an `OpenAIResponse`.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        // Tolerate a base URL configured with trailing slashes; plain concatenation would
        // otherwise request `...//v1/chat/completions`.
        let base = self.base_url.trim_end_matches('/');
        let url = format!("{}{}", base, self.runtime.completions_path());

        let messages = if self.supports_tool_calls {
            build_openai_messages(&request)
        } else {
            build_text_only_messages(&request)
        };

        let mut body = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
            "stream": false,
        });
        if self.supports_tool_calls {
            let tools = build_openai_tools(&request);
            // Only send `tools` when there is something to offer.
            if !tools.is_empty() {
                body["tools"] = serde_json::json!(tools);
            }
        }

        // `RequestBuilder::json` serializes the body and sets `Content-Type: application/json`.
        let response = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await
            .with_context(|| {
                format!(
                    "Failed to connect to local model at {}. Is {} running?",
                    self.base_url,
                    self.runtime.display_name()
                )
            })?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: the error body only enriches the message, so a read failure
            // degrades to an empty string rather than masking the status error.
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Local model error ({}) from {}: {}",
                status,
                self.runtime.display_name(),
                error_text
            );
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .context("Failed to parse response from local model")?;
        parse_openai_response(api_response)
    }

    /// Provider name, taken from the runtime's short identifier.
    fn name(&self) -> &str {
        self.runtime.display_name()
    }

    /// Model identifier sent with every request.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Whether structured tool calls are sent to the server.
    fn supports_tools(&self) -> bool {
        self.supports_tool_calls
    }
}
/// Flattens a request into plain `{role, content}` chat messages.
///
/// This is used for servers without structured tool-call support:
/// * A non-empty system prompt becomes a leading `system` message.
/// * Within each message, text blocks are concatenated verbatim.
/// * Tool calls are rendered as a JSON object wrapped in `<tool_call>` tags.
/// * Tool results are wrapped in `<tool_result>` tags.
fn build_text_only_messages(request: &CompletionRequest) -> Vec<serde_json::Value> {
    let mut messages = Vec::new();
    if !request.system.is_empty() {
        messages.push(serde_json::json!({
            "role": "system",
            "content": request.system,
        }));
    }
    for msg in &request.messages {
        let role = match msg.role {
            crate::query::types::Role::User => "user",
            crate::query::types::Role::Assistant => "assistant",
            crate::query::types::Role::System => "system",
        };
        let mut content = String::new();
        for block in &msg.content {
            match block {
                ContentBlock::Text { text } => {
                    content.push_str(text);
                }
                ContentBlock::ToolUse { id: _, name, input } => {
                    // Encode the name as a JSON string literal (quoted and escaped), so quotes or
                    // backslashes in it cannot break the JSON object shown to the model.
                    let name_json = serde_json::Value::String(name.to_string()).to_string();
                    // Fall back to `null` rather than an empty value, so the rendered object
                    // stays valid JSON even if serialization fails.
                    let input_json = serde_json::to_string_pretty(input)
                        .unwrap_or_else(|_| "null".to_string());
                    content.push_str(&format!(
                        "\n<tool_call>\n{{\n \"name\": {},\n \"input\": {}\n}}\n</tool_call>\n",
                        name_json, input_json
                    ));
                }
                ContentBlock::ToolResult { tool_use_id: _, content: result, is_error: _ } => {
                    content.push_str(&format!(
                        "\n<tool_result>\n{}\n</tool_result>\n",
                        result
                    ));
                }
            }
        }
        messages.push(serde_json::json!({
            "role": role,
            "content": content,
        }));
    }
    messages
}