ragit_api/
api_provider.rs

1use crate::error::Error;
2use crate::model::TestModel;
3use crate::response::{
4    AnthropicResponse,
5    CohereResponse,
6    GoogleResponse,
7    IntoChatResponse,
8    OpenAiResponse,
9};
10use std::fmt;
11
12#[derive(Clone, Debug, Eq, Hash, PartialEq)]
13pub enum ApiProvider {
14    OpenAi { url: String },
15    Cohere,
16    Anthropic,
17    Google,
18
19    /// for test
20    /// 1. doesn't require api key
21    /// 2. needs no network
22    Test(TestModel),
23}
24
25impl ApiProvider {
26    // TODO: why `XXXResponse` -> `Box<dyn IntoChatResponse>` -> `Response`?
27    //       why not just `XXXResponse` -> `Response`?
28    pub fn parse_chat_response(&self, s: &str) -> Result<Box<dyn IntoChatResponse>, Error> {
29        match self {
30            ApiProvider::Anthropic => Ok(Box::new(serde_json::from_str::<AnthropicResponse>(s)?)),
31            ApiProvider::Cohere => Ok(Box::new(serde_json::from_str::<CohereResponse>(s)?)),
32            ApiProvider::OpenAi { .. } => Ok(Box::new(serde_json::from_str::<OpenAiResponse>(s)?)),
33            ApiProvider::Google => Ok(Box::new(serde_json::from_str::<GoogleResponse>(s)?)),
34            ApiProvider::Test(_) => unreachable!(),
35        }
36    }
37
38    pub fn parse(s: &str, url: &Option<String>) -> Result<Self, Error> {
39        match s.to_ascii_lowercase().replace(" ", "").replace("-", "").as_str() {
40            "openai" => match url {
41                Some(url) => Ok(ApiProvider::OpenAi { url: url.to_string() }),
42                None => Ok(ApiProvider::OpenAi { url: String::from("https://api.openai.com/v1/chat/completions") }),
43            },
44            "cohere" => Ok(ApiProvider::Cohere),
45            "anthropic" => Ok(ApiProvider::Anthropic),
46            "google" => Ok(ApiProvider::Google),
47            _ => Err(Error::InvalidApiProvider(s.to_string())),
48        }
49    }
50}
51
52impl fmt::Display for ApiProvider {
53    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
54        write!(
55            fmt,
56            "{}",
57            match self {
58                ApiProvider::OpenAi { .. } => "openai",
59                ApiProvider::Cohere => "cohere",
60                ApiProvider::Anthropic => "anthropic",
61                ApiProvider::Google => "google",
62                ApiProvider::Test(_) => "test",
63            },
64        )
65    }
66}