use crate::config::{AiConfig, Framework, Question};
use crate::error::{AppError, Result};
use ollama_rs::{
generation::chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
Ollama,
};
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
use serde_json::Value;
use std::env;
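
/// Queries the OpenAI chat-completions API and returns the assistant's reply.
/// Reads `OPENAI_API_KEY` from the environment; the endpoint can be overridden
/// with `OPENAI_API_URL`.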
async fn get_openai_response(question: Question, ai_config: &AiConfig) -> Result<String> {
let api_key = env::var("OPENAI_API_KEY").map_err(|e| AppError::ApiError {
model_name: ai_config.llm.to_string(),
failure_str: format!("Missing or invalid OPENAI_API_KEY: {}", e),
})?;
    let mut messages = vec![];
    // Always lead with a system message; an empty prompt keeps the message
    // order stable when none is provided.
    messages.push(serde_json::json!({
        "role": "system",
        "content": question.system_prompt.unwrap_or_default()
    }));
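    // Replay the prior conversation, skipping empty turns.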
if let Some(prev_messages) = question.messages {
for msg in prev_messages.iter() {
if !msg.content.is_empty() {
messages.push(serde_json::json!({
"role": "user",
"content": msg.content
}));
}
if !msg.output.is_empty() {
messages.push(serde_json::json!({
"role": "assistant",
"content": msg.output
}));
}
}
}
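    // Never send an empty user turn; fall back to a minimal placeholder.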
let usr_input = if question.new_prompt.is_empty() {
".".to_string()
} else {
question.new_prompt
};
messages.push(serde_json::json!({
"role": "user",
"content": usr_input
}));
let payload = serde_json::json!({
"model": ai_config.model,
"messages": messages
});
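    // Allow the endpoint to be overridden, e.g. for proxies or compatible servers.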
let api_url = env::var("OPENAI_API_URL")
.unwrap_or_else(|_| "https://api.openai.com/v1/chat/completions".to_string());
let resp = reqwest::Client::new()
.post(&api_url)
.header(CONTENT_TYPE, "application/json")
.header(AUTHORIZATION, format!("Bearer {}", api_key))
.json(&payload)
.send()
.await
.map_err(|e| AppError::ApiError {
model_name: ai_config.llm.to_string(),
failure_str: format!("Request error: {}", e),
})?;
if !resp.status().is_success() {
let status = resp.status();
let err_body = resp.text().await.unwrap_or_default();
return Err(AppError::ApiError {
model_name: ai_config.llm.to_string(),
failure_str: format!("Status {}: {}", status, err_body),
});
}
let response: Value = resp.json().await.map_err(|e| AppError::ModelError {
model_name: ai_config.model.to_string(),
failure_str: format!("Failed to parse JSON response: {}", e),
})?;
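    // The reply is nested at choices[0].message.content.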
let answer = response["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| AppError::ModelError {
model_name: ai_config.model.to_string(),
failure_str: "Failed to extract content from OpenAI response".to_string(),
})?
.to_string();
Ok(answer)
}
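
/// Queries the Anthropic Messages API and returns the assistant's reply.
/// Reads `ANTHROPIC_API_KEY` from the environment; the endpoint can be
/// overridden with `ANTHROPIC_API_URL`.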
async fn get_anthropic_response(question: Question, ai_config: &AiConfig) -> Result<String> {
let api_key = env::var("ANTHROPIC_API_KEY").map_err(|e| AppError::ApiError {
model_name: ai_config.llm.to_string(),
failure_str: format!("Missing or invalid ANTHROPIC_API_KEY: {}", e),
})?;
let mut messages = vec![];
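    // Anthropic takes the system prompt as a top-level field in the payload,
    // so only user/assistant turns go into `messages`.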
if let Some(prev_messages) = question.messages {
for msg in prev_messages.iter() {
if !msg.content.is_empty() {
messages.push(serde_json::json!({
"role": "user",
"content": [{"type": "text", "text": msg.content}]
}));
}
if !msg.output.is_empty() {
messages.push(serde_json::json!({
"role": "assistant",
"content": [{"type": "text", "text": msg.output}]
}));
}
}
}
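    // As in the OpenAI path, never send an empty user turn.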
let usr_input = if question.new_prompt.is_empty() {
".".to_string()
} else {
question.new_prompt
};
messages.push(serde_json::json!({
"role": "user",
"content": [{"type": "text", "text": usr_input}]
}));
let system_prompt = question.system_prompt.unwrap_or_else(|| {
"You are a helpful assistant. Answer the question concisely.".to_string()
});
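    // The Messages API requires `max_tokens`; default to 1024 when unset.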
let max_tokens = ai_config.max_token.unwrap_or(1024);
let payload = serde_json::json!({
"model": ai_config.model,
"max_tokens": max_tokens,
"messages": messages,
"system": system_prompt
});
let api_url = env::var("ANTHROPIC_API_URL")
.unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
let resp = reqwest::Client::new()
.post(&api_url)
.header("x-api-key", api_key)
.header("anthropic-version", "2023-06-01")
.header(CONTENT_TYPE, "application/json")
.json(&payload)
.send()
.await
.map_err(|e| AppError::ApiError {
model_name: ai_config.llm.to_string(),
failure_str: format!("Request error: {}", e),
})?;
if !resp.status().is_success() {
let status = resp.status();
let err_body = resp.text().await.unwrap_or_default();
return Err(AppError::ApiError {
model_name: ai_config.llm.to_string(),
failure_str: format!("Status {}: {}", status, err_body),
});
}
let response: Value = resp.json().await.map_err(|e| AppError::ModelError {
model_name: ai_config.model.to_string(),
failure_str: format!("Failed to parse JSON response: {}", e),
})?;
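    // The reply arrives as a list of content blocks; take the first text block.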
let answer = response["content"][0]["text"]
.as_str()
.ok_or_else(|| AppError::ModelError {
model_name: ai_config.model.to_string(),
failure_str: "Failed to extract content from Anthropic response".to_string(),
})?
.to_string();
Ok(answer)
}
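
/// Queries a local Ollama server (the `ollama-rs` default, `http://localhost:11434`)
/// and returns the assistant's reply.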
async fn get_ollama_response(question: Question, ai_config: &AiConfig) -> Result<String> {
    let ollama = Ollama::default();
let mut msgs = vec![];
    // Use the provided system prompt, or the same default as the Anthropic path.
    let sys_prompt = question.system_prompt.unwrap_or_else(|| {
        String::from("You are a helpful assistant. Answer the question concisely.")
    });
    msgs.push(ChatMessage {
        role: MessageRole::System,
        content: sys_prompt,
        tool_calls: vec![],
        images: None,
    });
    // Replay the prior conversation, skipping empty turns.
    if let Some(prev_messages) = question.messages {
        for msg in prev_messages.iter() {
            if !msg.content.is_empty() {
                msgs.push(ChatMessage {
                    role: MessageRole::User,
                    content: msg.content.to_owned(),
                    tool_calls: vec![],
                    images: None,
                });
            }
            if !msg.output.is_empty() {
                msgs.push(ChatMessage {
                    role: MessageRole::Assistant,
                    content: msg.output.to_owned(),
                    tool_calls: vec![],
                    images: None,
                });
            }
        }
    }
    // Never send an empty user turn; fall back to a minimal placeholder.
    let usr_input = if question.new_prompt.is_empty() {
        String::from(".")
    } else {
        question.new_prompt
    };
    msgs.push(ChatMessage {
        role: MessageRole::User,
        content: usr_input,
        tool_calls: vec![],
        images: None,
    });
    // `msgs` already holds the full conversation, so send the request directly;
    // `send_chat_messages_with_history` would append the request's messages to
    // the supplied history and double every turn.
    let req = ChatMessageRequest::new(ai_config.model.to_owned(), msgs);
    let result = ollama
        .send_chat_messages(req)
        .await
        .map_err(|e| AppError::ModelError {
            model_name: ai_config.model.to_owned(),
            failure_str: e.to_string(),
        })?;
let answer = result.message.content;
Ok(answer)
}
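
/// Routes `question` to the backend selected by `ai_config.llm`.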
pub async fn ask_question(ai_config: &AiConfig, question: Question) -> Result<String> {
match ai_config.llm {
Framework::OpenAI => get_openai_response(question, ai_config).await,
Framework::Anthropic => get_anthropic_response(question, ai_config).await,
Framework::Ollama => get_ollama_response(question, ai_config).await,
}
}