ragit_api/
api_provider.rs

1use crate::error::Error;
2use crate::model::TestModel;
3use crate::response::{
4    AnthropicResponse,
5    CohereResponse,
6    IntoChatResponse,
7    OpenAiResponse,
8};
9use std::fmt;
10
11#[derive(Clone, Debug, Eq, Hash, PartialEq)]
12pub enum ApiProvider {
13    OpenAi { url: String },
14    Cohere,
15    Anthropic,
16
17    /// for test
18    /// 1. doesn't require api key
19    /// 2. needs no network
20    Test(TestModel),
21}
22
23impl ApiProvider {
24    // TODO: why `XXXResponse` -> `Box<dyn IntoChatResponse>` -> `Response`?
25    //       why not just `XXXResponse` -> `Response`?
26    pub fn parse_chat_response(&self, s: &str) -> Result<Box<dyn IntoChatResponse>, Error> {
27        match self {
28            ApiProvider::Anthropic => Ok(Box::new(serde_json::from_str::<AnthropicResponse>(s)?)),
29            ApiProvider::Cohere => Ok(Box::new(serde_json::from_str::<CohereResponse>(s)?)),
30            ApiProvider::OpenAi { .. } => Ok(Box::new(serde_json::from_str::<OpenAiResponse>(s)?)),
31            ApiProvider::Test(_) => unreachable!(),
32        }
33    }
34
35    pub fn parse(s: &str, url: &Option<String>) -> Result<Self, Error> {
36        match s.to_ascii_lowercase().replace(" ", "").replace("-", "").as_str() {
37            "openai" => match url {
38                Some(url) => Ok(ApiProvider::OpenAi { url: url.to_string() }),
39                None => Ok(ApiProvider::OpenAi { url: String::from("https://api.openai.com/v1/chat/completions") }),
40            },
41            "cohere" => Ok(ApiProvider::Cohere),
42            "anthropic" => Ok(ApiProvider::Anthropic),
43            _ => Err(Error::InvalidApiProvider(s.to_string())),
44        }
45    }
46
47    pub fn get_api_url(&self) -> &str {
48        match self {
49            ApiProvider::Anthropic => "https://api.anthropic.com/v1/messages",
50            ApiProvider::Cohere => "https://api.cohere.com/v2/chat",
51            ApiProvider::OpenAi { url } => url,
52            ApiProvider::Test(_) => "",
53        }
54    }
55}
56
57impl fmt::Display for ApiProvider {
58    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
59        write!(
60            fmt,
61            "{}",
62            match self {
63                ApiProvider::OpenAi { .. } => "openai",
64                ApiProvider::Cohere => "cohere",
65                ApiProvider::Anthropic => "anthropic",
66                ApiProvider::Test(_) => "test",
67            },
68        )
69    }
70}