use async_trait::async_trait;
use std::collections::HashMap;
use crate::error::Result;
use crate::puppet::{PromptRequest, PromptResponse};
use crate::session::Session;
/// A web service this crate can talk to.
///
/// Each variant exists only when the Cargo feature equal to its `name()` is enabled
/// (e.g. feature `chatgpt` for `ChatGpt`). With no provider features enabled the enum
/// has no variants at all, so every method below must also compile in that case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    /// Grok.
    #[cfg(feature = "grok")]
    Grok,
    /// Claude.
    #[cfg(feature = "claude")]
    Claude,
    /// Gemini.
    #[cfg(feature = "gemini")]
    Gemini,
    /// ChatGPT.
    #[cfg(feature = "chatgpt")]
    ChatGpt,
    /// Perplexity.
    #[cfg(feature = "perplexity")]
    Perplexity,
    /// NotebookLM.
    #[cfg(feature = "notebooklm")]
    NotebookLm,
    /// Kaggle.
    #[cfg(feature = "kaggle")]
    Kaggle,
}

impl Provider {
    /// Returns the canonical lowercase identifier (e.g. `"chatgpt"`).
    ///
    /// Every canonical identifier is also accepted by [`Provider::from_string`].
    pub fn name(&self) -> &'static str {
        // Match on `*self`, not `self`: with no provider features the enum is empty,
        // and an arm-less match on `&Provider` is rejected as non-exhaustive because
        // references are always treated as inhabited. `Provider` is `Copy`, so this
        // dereference is free.
        match *self {
            #[cfg(feature = "grok")]
            Provider::Grok => "grok",
            #[cfg(feature = "claude")]
            Provider::Claude => "claude",
            #[cfg(feature = "gemini")]
            Provider::Gemini => "gemini",
            #[cfg(feature = "chatgpt")]
            Provider::ChatGpt => "chatgpt",
            #[cfg(feature = "perplexity")]
            Provider::Perplexity => "perplexity",
            #[cfg(feature = "notebooklm")]
            Provider::NotebookLm => "notebooklm",
            #[cfg(feature = "kaggle")]
            Provider::Kaggle => "kaggle",
        }
    }

    /// Returns the `https://` URL of the provider's web front end.
    pub fn base_url(&self) -> &'static str {
        // `*self` for the same empty-enum reason as in `name`.
        match *self {
            #[cfg(feature = "grok")]
            Provider::Grok => "https://x.com/i/grok",
            #[cfg(feature = "claude")]
            Provider::Claude => "https://claude.ai",
            #[cfg(feature = "gemini")]
            Provider::Gemini => "https://gemini.google.com",
            #[cfg(feature = "chatgpt")]
            Provider::ChatGpt => "https://chat.openai.com",
            #[cfg(feature = "perplexity")]
            Provider::Perplexity => "https://www.perplexity.ai",
            #[cfg(feature = "notebooklm")]
            Provider::NotebookLm => "https://notebooklm.google.com",
            #[cfg(feature = "kaggle")]
            Provider::Kaggle => "https://www.kaggle.com",
        }
    }

    /// Parses a provider name or alias, ignoring case.
    ///
    /// Besides each canonical name, aliases are accepted (e.g. `"openai"`/`"gpt"` for
    /// ChatGPT, `"pplx"` for Perplexity, `"nlm"` for NotebookLM). Returns `None` for
    /// unrecognised input and for providers whose feature is disabled.
    pub fn from_string(s: &str) -> Option<Self> {
        // Full Unicode lowercasing, so matching is case-insensitive beyond ASCII.
        match s.to_lowercase().as_str() {
            #[cfg(feature = "grok")]
            "grok" | "xai" | "x" => Some(Provider::Grok),
            #[cfg(feature = "claude")]
            "claude" | "anthropic" => Some(Provider::Claude),
            #[cfg(feature = "gemini")]
            "gemini" | "google" | "bard" => Some(Provider::Gemini),
            #[cfg(feature = "chatgpt")]
            "chatgpt" | "openai" | "gpt" => Some(Provider::ChatGpt),
            #[cfg(feature = "perplexity")]
            "perplexity" | "pplx" => Some(Provider::Perplexity),
            #[cfg(feature = "notebooklm")]
            "notebooklm" | "notebook" | "nlm" => Some(Provider::NotebookLm),
            #[cfg(feature = "kaggle")]
            "kaggle" => Some(Provider::Kaggle),
            _ => None,
        }
    }

    /// Returns every enabled provider, in declaration order.
    pub fn all() -> Vec<Provider> {
        vec![
            #[cfg(feature = "grok")]
            Provider::Grok,
            #[cfg(feature = "claude")]
            Provider::Claude,
            #[cfg(feature = "gemini")]
            Provider::Gemini,
            #[cfg(feature = "chatgpt")]
            Provider::ChatGpt,
            #[cfg(feature = "perplexity")]
            Provider::Perplexity,
            #[cfg(feature = "notebooklm")]
            Provider::NotebookLm,
            #[cfg(feature = "kaggle")]
            Provider::Kaggle,
        ]
    }

    /// Returns the enabled providers grouped as search-oriented, in the listed order.
    ///
    /// NOTE(review): whether the order is a priority ranking is not visible here; confirm
    /// with callers.
    pub fn search_providers() -> Vec<Provider> {
        vec![
            #[cfg(feature = "grok")]
            Provider::Grok,
            #[cfg(feature = "perplexity")]
            Provider::Perplexity,
            #[cfg(feature = "chatgpt")]
            Provider::ChatGpt,
            #[cfg(feature = "kaggle")]
            Provider::Kaggle,
        ]
    }

    /// Returns the enabled providers grouped as large-context, in the listed order.
    ///
    /// NOTE(review): whether the order is a priority ranking is not visible here; confirm
    /// with callers.
    pub fn large_context_providers() -> Vec<Provider> {
        vec![
            #[cfg(feature = "notebooklm")]
            Provider::NotebookLm,
            #[cfg(feature = "gemini")]
            Provider::Gemini,
            #[cfg(feature = "claude")]
            Provider::Claude,
            #[cfg(feature = "chatgpt")]
            Provider::ChatGpt,
            #[cfg(feature = "grok")]
            Provider::Grok,
        ]
    }
}
impl std::str::FromStr for Provider {
    type Err = String;

    /// Parses a provider name or alias by delegating to `Provider::from_string`.
    ///
    /// # Errors
    ///
    /// Returns `"Unknown provider: <input>"` (with the input unchanged) when no enabled
    /// provider matches.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match Self::from_string(s) {
            Some(provider) => Ok(provider),
            None => Err(format!("Unknown provider: {}", s)),
        }
    }
}
impl std::fmt::Display for Provider {
    /// Writes the provider's canonical name, as returned by `Provider::name`.
    ///
    /// Width, fill, alignment and precision flags are honoured the same way as for
    /// `&str`, so `format!("{:<10}|", provider)` pads the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!(f, "{}", ..)` formats through a fresh, flag-less argument and would
        // silently ignore any padding the caller asked for; `pad` applies it.
        f.pad(self.name())
    }
}
/// Feature flags and limits a provider implementation advertises.
///
/// `Default` yields every flag `false`, `max_context: None` and an empty `models` list.
#[derive(Debug, Clone, Default)]
pub struct ProviderCapabilities {
    /// Conversation support (presumably multi-turn chat).
    pub conversation: bool,
    /// Image-input support.
    pub vision: bool,
    /// File-upload support.
    pub file_upload: bool,
    /// Code-execution support.
    pub code_execution: bool,
    /// Web-search support.
    pub web_search: bool,
    /// Maximum context size, when known.
    /// NOTE(review): the unit (tokens vs. characters) is not stated here; confirm.
    pub max_context: Option<usize>,
    /// Model names the provider offers (presumably selectable; confirm).
    pub models: Vec<String>,
}
/// Operations a provider backend implements against a [`Session`].
///
/// All I/O methods are `async` (through `async_trait`), and implementors must be
/// `Send + Sync`.
/// NOTE(review): `Session` is presumably a browser-automation session; confirm.
#[async_trait]
pub trait ProviderTrait: Send + Sync {
    /// The [`Provider`] this implementation drives.
    fn provider(&self) -> Provider;

    /// The capabilities this implementation advertises.
    fn capabilities(&self) -> ProviderCapabilities;

    /// Reports whether `session` is already authenticated with this provider.
    async fn is_authenticated(&self, session: &Session) -> Result<bool>;

    /// Authenticates `session`. This is the only method that takes the session mutably.
    async fn authenticate(&self, session: &mut Session) -> Result<()>;

    /// Sends `request` and returns the provider's response.
    async fn send_prompt(&self, session: &Session, request: &PromptRequest) -> Result<PromptResponse>;

    /// Starts a new conversation and returns an identifier for it.
    /// NOTE(review): presumably the id later passed to `continue_conversation`; confirm.
    async fn new_conversation(&self, session: &Session) -> Result<String>;

    /// Sends `request` within the existing conversation `conversation_id`.
    async fn continue_conversation(
        &self,
        session: &Session,
        conversation_id: &str,
        request: &PromptRequest,
    ) -> Result<PromptResponse>;

    /// Returns the URL `session` is currently at.
    async fn current_url(&self, session: &Session) -> Result<String>;

    /// Waits until the provider is ready; the readiness criterion is implementation-defined.
    async fn wait_ready(&self, session: &Session) -> Result<()>;

    /// Extracts response text from `session` (presumably the latest reply; confirm).
    async fn extract_response(&self, session: &Session) -> Result<String>;

    /// Checks for rate limiting.
    /// NOTE(review): `Some(d)` presumably means "wait `d` before retrying" and `None`
    /// means not limited; confirm against implementations.
    async fn check_rate_limit(
        &self,
        session: &Session,
    ) -> Result<Option<std::time::Duration>>;
}
/// Human-facing descriptive information about a [`Provider`].
#[derive(Debug, Clone)]
// NOTE(review): the lint is silenced presumably because nothing in the crate reads
// this type yet; confirm and remove the `allow` once it is used.
#[allow(dead_code)]
pub struct ProviderMetadata {
    /// The provider this entry describes.
    pub provider: Provider,
    /// Name suitable for display, e.g. `"ChatGPT"`.
    pub display_name: String,
    /// Short one-line description.
    pub description: String,
    /// Link to the provider's documentation, if any.
    pub docs_url: Option<String>,
    /// Link to the provider's terms of service, if any.
    pub tos_url: Option<String>,
    /// Named page selectors.
    /// NOTE(review): presumably logical element name -> CSS selector; confirm.
    pub selectors: HashMap<String, String>,
}
impl ProviderMetadata {
#[allow(dead_code)]
pub fn new(provider: Provider) -> Self {
match provider {
#[cfg(feature = "grok")]
Provider::Grok => Self {
provider,
display_name: "Grok".into(),
description: "X.ai's conversational AI".into(),
docs_url: Some("https://x.ai".into()),
tos_url: Some("https://x.com/tos".into()),
selectors: HashMap::new(),
},
#[cfg(feature = "claude")]
Provider::Claude => Self {
provider,
display_name: "Claude".into(),
description: "Anthropic's helpful AI assistant".into(),
docs_url: Some("https://docs.anthropic.com".into()),
tos_url: Some("https://www.anthropic.com/legal/consumer-terms".into()),
selectors: HashMap::new(),
},
#[cfg(feature = "gemini")]
Provider::Gemini => Self {
provider,
display_name: "Gemini".into(),
description: "Google's multimodal AI".into(),
docs_url: Some("https://ai.google.dev".into()),
tos_url: Some("https://policies.google.com/terms".into()),
selectors: HashMap::new(),
},
#[cfg(feature = "chatgpt")]
Provider::ChatGpt => Self {
provider,
display_name: "ChatGPT".into(),
description: "OpenAI's conversational AI".into(),
docs_url: Some("https://platform.openai.com/docs".into()),
tos_url: Some("https://openai.com/policies/terms-of-use".into()),
selectors: HashMap::new(),
},
#[cfg(feature = "perplexity")]
Provider::Perplexity => Self {
provider,
display_name: "Perplexity".into(),
description: "AI-powered search engine".into(),
docs_url: Some("https://docs.perplexity.ai".into()),
tos_url: Some("https://www.perplexity.ai/tos".into()),
selectors: HashMap::new(),
},
#[cfg(feature = "notebooklm")]
Provider::NotebookLm => Self {
provider,
display_name: "NotebookLM".into(),
description: "Google's AI research assistant".into(),
docs_url: Some("https://notebooklm.google.com".into()),
tos_url: Some("https://policies.google.com/terms".into()),
selectors: HashMap::new(),
},
#[cfg(feature = "kaggle")]
Provider::Kaggle => Self {
provider,
display_name: "Kaggle".into(),
description: "Dataset search and catalog".into(),
docs_url: Some("https://www.kaggle.com/docs".into()),
tos_url: Some("https://www.kaggle.com/terms".into()),
selectors: HashMap::new(),
},
}
}
}