larpshell 0.1.1

Ctrl+C then Ctrl+V is simply too much work. Just let the LLMs rule your terminal directly!!
use crate::config::{OpenAIConfig, OpenRouterConfig};
use crate::error::LarpshellError;
use crate::providers::AIProvider;
use crate::providers::base::{BaseProvider, strip_url_for_display};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// Sampling temperature sent with every chat-completions request built in this module.
const DEFAULT_TEMPERATURE: f32 = 0.7;

/// Shared implementation for providers that speak an OpenAI-compatible
/// chat-completions API; wrapped by [`OpenAIProvider`] and [`OpenRouterProvider`].
pub struct OpenAICompatibleProvider {
    /// Shared HTTP plumbing; its `client` is used to send requests.
    base: BaseProvider,
    /// API root. `/chat/completions` (or `/v1/chat/completions` when the URL does
    /// not already end in `/v1`) is appended to form the endpoint.
    base_url: String,
    /// Sent as an `Authorization: Bearer` header when present.
    api_key: Option<String>,
    /// Model identifier sent in each request body.
    model: String,
    /// Short provider identifier passed to error constructors (e.g. `"openai"`).
    provider_slug: &'static str,
    /// Human-readable provider label used when building `name()` (e.g. `"OpenAI"`).
    display_name: &'static str,
}

/// [`AIProvider`] for OpenAI, built from an [`OpenAIConfig`].
pub struct OpenAIProvider {
    // All request logic lives in the shared OpenAI-compatible implementation.
    inner: OpenAICompatibleProvider,
}

/// [`AIProvider`] for OpenRouter, built from an [`OpenRouterConfig`].
pub struct OpenRouterProvider {
    // All request logic lives in the shared OpenAI-compatible implementation.
    inner: OpenAICompatibleProvider,
}

/// Request body for the OpenAI-compatible `chat/completions` endpoint.
///
/// Borrows the model name and message text so building a request does not
/// copy the prompt.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<Message<'a>>,
    temperature: f32,
}

/// A single chat message inside a [`ChatRequest`].
#[derive(Serialize)]
struct Message<'a> {
    role: &'a str,
    content: &'a str,
}

/// The subset of the chat-completions response body this module reads.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

/// One completion alternative within a [`ChatResponse`].
#[derive(Deserialize)]
struct Choice {
    message: MessageResponse,
}

/// The assistant message inside a [`Choice`].
///
/// Both fields are optional: the OpenAI API documents `content` as nullable
/// (e.g. when the model refuses, in which case `refusal` carries the reason),
/// and compatible servers may omit `refusal` entirely. A plain `String` for
/// `content` would turn a `null` into an opaque deserialization failure.
#[derive(Deserialize)]
struct MessageResponse {
    content: Option<String>,
    refusal: Option<String>,
}

impl OpenAICompatibleProvider {
    /// Creates a provider for an OpenAI-compatible chat-completions API.
    ///
    /// `provider_slug` is the short identifier passed to error constructors;
    /// `display_name` is the human-readable label used by [`Self::name`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `BaseProvider::new`.
    fn new(
        base_url: String,
        api_key: Option<String>,
        model: String,
        provider_slug: &'static str,
        display_name: &'static str,
    ) -> Result<Self, LarpshellError> {
        Ok(Self {
            base: BaseProvider::new()?,
            base_url,
            api_key,
            model,
            provider_slug,
            display_name,
        })
    }

    /// Builds the chat-completions endpoint URL from `base_url`.
    ///
    /// Trailing slashes are ignored, and `/v1` is inserted only when the base
    /// URL does not already end with it. For example, both `https://host` and
    /// `https://host/v1/` resolve to `https://host/v1/chat/completions`.
    fn chat_completions_url(&self) -> String {
        let base = self.base_url.trim_end_matches('/');
        if base.ends_with("/v1") {
            format!("{}/chat/completions", base)
        } else {
            format!("{}/v1/chat/completions", base)
        }
    }

    /// Sends `prompt` as a single user message and returns the text of the
    /// first choice.
    ///
    /// The `Authorization: Bearer …` header is sent only when an API key is
    /// configured.
    ///
    /// # Errors
    ///
    /// - Transport failures, converted with `LarpshellError::from_reqwest`.
    /// - Whatever `BaseProvider::check_response` returns for the HTTP response.
    /// - `LarpshellError::InvalidResponse` if the body cannot be parsed, has no
    ///   choices, or the first choice carries no text content. A refusal
    ///   message, if the server supplied one, is included in the error.
    async fn generate(&self, prompt: &str) -> Result<String, LarpshellError> {
        let url = self.chat_completions_url();

        let request_body = ChatRequest {
            model: &self.model,
            messages: vec![Message {
                role: "user",
                content: prompt,
            }],
            temperature: DEFAULT_TEMPERATURE,
        };

        let mut request = self.base.client.post(&url).json(&request_body);
        if let Some(api_key) = self.api_key.as_deref() {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request
            .send()
            .await
            .map_err(|e| LarpshellError::from_reqwest(e, self.provider_slug))?;

        let response = BaseProvider::check_response(response, self.provider_slug).await?;

        let chat_response: ChatResponse = response.json().await.map_err(|e| {
            LarpshellError::InvalidResponse(format!(
                "failed to parse {} response: {}",
                self.provider_slug, e
            ))
        })?;

        // Move the first choice out of the response rather than cloning its text.
        let message = chat_response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| {
                LarpshellError::InvalidResponse(format!("no response from {}", self.provider_slug))
            })?;

        match message.content {
            Some(content) => Ok(content),
            None => {
                let reason = match message.refusal {
                    Some(refusal) => {
                        format!("{} refused the request: {}", self.provider_slug, refusal)
                    }
                    None => format!("no content in response from {}", self.provider_slug),
                };
                Err(LarpshellError::InvalidResponse(reason))
            }
        }
    }

    /// Returns a label of the form `"<display_name> (<base_url>)"`, where the
    /// URL is first passed through `strip_url_for_display`.
    fn name(&self) -> String {
        format!(
            "{} ({})",
            self.display_name,
            strip_url_for_display(&self.base_url)
        )
    }
}

impl OpenAIProvider {
    /// Builds an OpenAI provider from `config`.
    ///
    /// The base URL, API key and model are copied out of the config, and the
    /// provider is labelled with slug `"openai"` and display name `"OpenAI"`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while constructing the shared
    /// OpenAI-compatible provider.
    pub fn new(config: &OpenAIConfig) -> Result<Self, LarpshellError> {
        OpenAICompatibleProvider::new(
            config.base_url.clone(),
            config.api_key.clone(),
            config.model.clone(),
            "openai",
            "OpenAI",
        )
        .map(|inner| Self { inner })
    }
}

impl OpenRouterProvider {
    /// Builds an OpenRouter provider from `config`.
    ///
    /// The base URL, API key and model are copied out of the config, and the
    /// provider is labelled with slug `"openrouter"` and display name `"OpenRouter"`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while constructing the shared
    /// OpenAI-compatible provider.
    pub fn new(config: &OpenRouterConfig) -> Result<Self, LarpshellError> {
        OpenAICompatibleProvider::new(
            config.base_url.clone(),
            config.api_key.clone(),
            config.model.clone(),
            "openrouter",
            "OpenRouter",
        )
        .map(|inner| Self { inner })
    }
}

#[async_trait]
impl AIProvider for OpenAIProvider {
    async fn generate(&self, prompt: &str) -> Result<String, LarpshellError> {
        self.inner.generate(prompt).await
    }

    fn name(&self) -> String {
        self.inner.name()
    }
}

#[async_trait]
impl AIProvider for OpenRouterProvider {
    async fn generate(&self, prompt: &str) -> Result<String, LarpshellError> {
        self.inner.generate(prompt).await
    }

    fn name(&self) -> String {
        self.inner.name()
    }
}