shellchat 1.0.39

Transforms natural language into shell commands for execution or explanation.
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use std::panic::panic_any;
use std::sync::Arc;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ProviderError {
    #[error("HTTP Request Error: {0}")]
    RequestError(#[from] reqwest::Error),
    #[error("JSON Parsing Error: {0}")]
    JsonError(#[from] serde_json::Error),
    #[error("Unexpected Response Structure")]
    UnexpectedResponse(String),
}

/// A language-model backend that answers a user prompt under a given role/system prompt.
#[async_trait]
pub trait ProviderApi {
    /// Sends `role_prompt` and `user_prompt` to the backend and returns its text reply.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the request fails or the reply cannot be understood.
    async fn call(
        &self,
        role_prompt: &str,
        user_prompt: &str,
    ) -> Result<String, ProviderError>;
}

/// Provider selection as read from configuration.
///
/// Internally tagged through `#[serde(tag = "type")]`: the `type` field holds the variant
/// name, e.g. `{"type": "AzureOpenAI", "api_key": "...", "api_url": "...", "model": "..."}`.
///
/// NOTE(review): the derived `Debug` output includes `api_key`; never log or format a
/// `ProviderConfig` with `{:?}` where the output may be seen.
#[derive(Debug, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ProviderConfig {
    /// OpenAI settings. `new_provider` does not support this variant yet.
    OpenAI {
        /// Secret API key.
        api_key: String,
        /// Base URL of the API.
        api_url: String,
        /// Model name.
        model: String,
    },
    /// Azure OpenAI settings; the only variant `new_provider` currently builds.
    AzureOpenAI {
        /// Secret key, sent in the `api-key` request header.
        api_key: String,
        /// Resource endpoint; `/openai/deployments/...` is appended to it.
        api_url: String,
        /// Used both as the deployment name in the URL path and as the body's `model`.
        model: String,
    },
    /// Ollama settings. `new_provider` does not support this variant yet.
    Ollama {
        /// API key (presumably unused by Ollama and kept for a uniform schema — TODO confirm).
        api_key: String,
        /// Base URL of the Ollama server.
        api_url: String,
        /// Model name.
        model: String,
    },
}

/// Chat-completion client bound to a single Azure OpenAI deployment.
///
/// No `Debug` is derived; keep it that way, since the struct holds the API key.
#[derive(Clone)]
pub struct AzureOpenAI {
    /// Complete endpoint URL that requests are POSTed to (including `api-version`).
    pub url_full: String,
    /// Value sent in the `api-key` request header.
    pub api_key: String,
    /// Sent as the `model` field of the request body.
    pub model: String,
    /// HTTP client used to send requests.
    pub client: Client,
}

#[derive(serde::Deserialize)]
pub struct CompletionResponse {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize)]
pub struct Choice {
    message: Message,
}

#[derive(serde::Deserialize)]
pub struct Message {
    content: String,
}

#[async_trait]
impl ProviderApi for AzureOpenAI {
    /// Sends one chat-completion request to the deployment at `url_full`.
    ///
    /// The body contains `model`, then a `system` message (`role_prompt`) followed by a
    /// `user` message (`user_prompt`). The request authenticates via the `api-key` header.
    /// Returns the `content` of the first choice.
    ///
    /// # Errors
    ///
    /// * [`ProviderError::RequestError`] if sending the request or reading the body fails.
    /// * [`ProviderError::UnexpectedResponse`] with the raw body if it does not deserialize
    ///   into a completion response or has no choices. The HTTP status is not inspected,
    ///   so error replies (e.g. 401/429) also end up here.
    async fn call(&self, role_prompt: &str, user_prompt: &str) -> Result<String, ProviderError> {
        let body = json!({
            "model": &self.model,
            "messages": [
                { "role": "system", "content": role_prompt },
                { "role": "user", "content": user_prompt },
            ],
        });

        let response = self
            .client
            .post(&self.url_full)
            .header("api-key", &self.api_key)
            .json(&body)
            .send()
            .await?
            .text()
            .await?;

        // `CompletionResponse` owns its data, so `response` is free to be moved into the
        // error below instead of being copied.
        let parsed = match serde_json::from_str::<CompletionResponse>(&response) {
            Ok(parsed) => parsed,
            Err(_) => return Err(ProviderError::UnexpectedResponse(response)),
        };

        // Take ownership of the first choice rather than cloning its content.
        match parsed.choices.into_iter().next() {
            Some(choice) => Ok(choice.message.content),
            None => Err(ProviderError::UnexpectedResponse(response)),
        }
    }
}

/// Builds the provider described by `provider_type`, wrapped for sharing across tasks.
///
/// Only [`ProviderConfig::AzureOpenAI`] is supported so far. Its request URL is
/// `{api_url}/openai/deployments/{model}/chat/completions?api-version=2024-02-01`, so
/// `model` names the Azure deployment. Trailing `/` characters on `api_url` are dropped
/// first, so an endpoint copied as `https://<resource>.openai.azure.com/` does not
/// produce a `//openai/...` path.
///
/// # Panics
///
/// Panics with a `String` payload naming the provider for the `OpenAI` and `Ollama`
/// variants, which are not implemented yet. Only the variant name is included, because
/// the configuration's derived `Debug` output contains the API key.
pub fn new_provider(provider_type: &ProviderConfig) -> Arc<dyn ProviderApi + Send + Sync> {
    match provider_type {
        ProviderConfig::AzureOpenAI {
            api_key,
            api_url,
            model,
        } => Arc::new(AzureOpenAI {
            client: Client::new(),
            model: model.clone(),
            api_key: api_key.clone(),
            url_full: format!(
                "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
                api_url.trim_end_matches('/'),
                model,
            ),
        }),
        // Listed explicitly (no `_` arm) so adding a variant forces a decision here.
        ProviderConfig::OpenAI { .. } => panic_unsupported_provider("OpenAI"),
        ProviderConfig::Ollama { .. } => panic_unsupported_provider("Ollama"),
    }
}

/// Aborts provider construction for a backend without an implementation.
fn panic_unsupported_provider(name: &str) -> ! {
    panic_any(format!("the provider not implemented yet: {}", name))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in provider that never touches the network and always gives the same reply.
    struct MockProvider;

    #[async_trait::async_trait]
    impl ProviderApi for MockProvider {
        async fn call(
            &self,
            _role_prompt: &str,
            _user_prompt: &str,
        ) -> Result<String, ProviderError> {
            Ok(String::from("Mock response"))
        }
    }

    /// Drives the trait through the mock: the call must succeed and return the canned text.
    #[tokio::test]
    async fn test_provider_call() {
        let mock = MockProvider;
        let reply = mock
            .call("", "test")
            .await
            .expect("mock provider must not fail");
        assert_eq!(reply, "Mock response");
    }
}