//! Provider-agnostic chat types and helpers shared by the LLM backends.

use crate::errors::{Result, SpiderError};
use serde::{Deserialize, Serialize};

/// Which LLM backend to talk to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LLMProviderKind {
    OpenAI,
    Anthropic,
    OpenRouter,
}

/// Connection settings for a single LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    pub provider: LLMProviderKind,
    /// Model identifier, e.g. "gpt-4o-mini" or "claude-3-5-sonnet-latest".
    pub model: String,
    pub api_key: String,
    /// Optional override of the provider's default API endpoint.
    #[serde(default)]
    pub base_url: Option<String>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Sampling temperature; the provider's default is used when unset.
    #[serde(default)]
    pub temperature: Option<f64>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LLMRole {
    System,
    User,
    Assistant,
}

/// One part of a multimodal message, internally tagged via a `"type"` key
/// to match the OpenAI-style content-part shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum LLMContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlValue },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlValue {
    pub url: String,
}

/// Message content: either a plain string or a list of multimodal parts.
/// `untagged` lets both shapes (de)serialize without a wrapper key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum LLMContent {
    Text(String),
    Parts(Vec<LLMContentPart>),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMMessage {
    pub role: LLMRole,
    pub content: LLMContent,
}

impl LLMMessage {
    /// Plain-text system message.
    pub fn system(text: impl Into<String>) -> Self {
        Self {
            role: LLMRole::System,
            content: LLMContent::Text(text.into()),
        }
    }

    /// Plain-text user message.
    pub fn user(text: impl Into<String>) -> Self {
        Self {
            role: LLMRole::User,
            content: LLMContent::Text(text.into()),
        }
    }

    /// User message built from multimodal parts (e.g. text plus an image).
    pub fn user_parts(parts: Vec<LLMContentPart>) -> Self {
        Self {
            role: LLMRole::User,
            content: LLMContent::Parts(parts),
        }
    }

    /// Plain-text assistant message.
    pub fn assistant(text: impl Into<String>) -> Self {
        Self {
            role: LLMRole::Assistant,
            content: LLMContent::Text(text.into()),
        }
    }
}

/// Per-request options passed to [`LLMProvider::chat`].
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
    /// Ask the provider for a JSON-only response where supported.
    pub json_mode: bool,
}

/// A chat-completion backend. Implementations are object-safe so they can be
/// boxed by [`create_provider`].
#[async_trait::async_trait]
pub trait LLMProvider: Send + Sync {
    /// Send the conversation and return the assistant's reply as plain text.
    async fn chat(&self, messages: &[LLMMessage], options: Option<ChatOptions>) -> Result<String>;
}
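
/// Run a chat in JSON mode and deserialize the reply into `T`, using
/// [`parse_json_response`] to strip any markdown code fence first.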
pub async fn chat_json<T: serde::de::DeserializeOwned>(
    llm: &dyn LLMProvider,
    messages: &[LLMMessage],
) -> Result<T> {
    let text = llm
        .chat(messages, Some(ChatOptions { json_mode: true }))
        .await?;
    parse_json_response(&text)
}

/// Parse an LLM reply as JSON, tolerating a markdown code fence around it.
pub fn parse_json_response<T: serde::de::DeserializeOwned>(text: &str) -> Result<T> {
    // Happy path: the reply is bare JSON.
    if let Ok(val) = serde_json::from_str::<T>(text) {
        return Ok(val);
    }
    // Fallback: extract the first fenced block (``` or ```json ... ```).
    if let Some(start) = text.find("```") {
        let after_fence = &text[start + 3..];
        // Skip the language tag, if any, by jumping past the first newline.
        let json_start = after_fence.find('\n').map(|i| i + 1).unwrap_or(0);
        if let Some(end) = after_fence[json_start..].find("```") {
            let json_str = &after_fence[json_start..json_start + end];
            if let Ok(val) = serde_json::from_str::<T>(json_str.trim()) {
                return Ok(val);
            }
        }
    }
    // Truncate the error preview by characters, not bytes: slicing at a raw
    // byte index could panic on a multi-byte UTF-8 boundary.
    let preview: String = text.chars().take(200).collect();
    Err(SpiderError::Llm(format!(
        "LLM response is not valid JSON: {preview}"
    )))
}
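
/// Build a boxed provider from `config`. OpenRouter speaks the
/// OpenAI-compatible API, so it shares that implementation.
///
/// A minimal usage sketch (the model name and prompt are illustrative, and
/// the block is `ignore`d because a real call needs a live API key):
///
/// ```ignore
/// let provider = create_provider(LLMConfig {
///     provider: LLMProviderKind::OpenAI,
///     model: "gpt-4o-mini".to_string(),
///     api_key: std::env::var("OPENAI_API_KEY")?,
///     base_url: None,
///     max_tokens: Some(512),
///     temperature: Some(0.0),
/// });
/// let reply = provider
///     .chat(&[LLMMessage::user("Say hello")], None)
///     .await?;
/// ```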
pub fn create_provider(config: LLMConfig) -> Box<dyn LLMProvider> {
    match config.provider {
        // OpenRouter exposes an OpenAI-compatible endpoint.
        LLMProviderKind::OpenAI | LLMProviderKind::OpenRouter => {
            Box::new(crate::ai::providers::openai::OpenAICompatibleProvider::new(config))
        }
        LLMProviderKind::Anthropic => {
            Box::new(crate::ai::providers::anthropic::AnthropicProvider::new(config))
        }
    }
}
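
// Illustrative unit tests (a sketch, not from the original source). The
// serialization and parsing tests need only serde_json, which this module
// already uses; the round-trip test assumes `tokio` (with the `macros` and
// `rt` features) is available as a dev-dependency for the async runtime.
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq, serde::Deserialize)]
    struct Payload {
        name: String,
    }

    /// A canned provider used to exercise `chat_json` without the network.
    struct FixedProvider(String);

    #[async_trait::async_trait]
    impl LLMProvider for FixedProvider {
        async fn chat(&self, _messages: &[LLMMessage], _options: Option<ChatOptions>) -> Result<String> {
            Ok(self.0.clone())
        }
    }

    #[test]
    fn parses_bare_json() {
        let parsed: Payload = parse_json_response(r#"{"name":"spider"}"#).unwrap();
        assert_eq!(parsed, Payload { name: "spider".into() });
    }

    #[test]
    fn parses_fenced_json() {
        let text = "Here you go:\n```json\n{\"name\":\"spider\"}\n```\n";
        let parsed: Payload = parse_json_response(text).unwrap();
        assert_eq!(parsed.name, "spider");
    }

    #[test]
    fn rejects_non_json() {
        assert!(parse_json_response::<Payload>("not json").is_err());
    }

    #[test]
    fn multimodal_message_serializes_with_type_tags() {
        let msg = LLMMessage::user_parts(vec![
            LLMContentPart::Text { text: "hi".into() },
            LLMContentPart::ImageUrl {
                image_url: ImageUrlValue { url: "https://example.com/a.png".into() },
            },
        ]);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"][0]["type"], "text");
        assert_eq!(json["content"][1]["image_url"]["url"], "https://example.com/a.png");
    }

    #[tokio::test]
    async fn chat_json_round_trips() {
        let provider = FixedProvider("```json\n{\"name\":\"spider\"}\n```".to_string());
        let parsed: Payload = chat_json(&provider, &[LLMMessage::user("give me JSON")])
            .await
            .unwrap();
        assert_eq!(parsed.name, "spider");
    }
}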