use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock;
use crate::config::Config;
use crate::credentials::CredentialStore;
use crate::error::{Error, Result};
use crate::providers::{Provider, ProviderTrait};
use crate::ratelimit::RateLimiter;
use crate::security::{ContentScreener, ScreeningConfig, ScreeningResult};
use crate::session::Session;
#[cfg(feature = "chatgpt")]
use crate::providers::ChatGptProvider;
#[cfg(feature = "claude")]
use crate::providers::ClaudeProvider;
#[cfg(feature = "gemini")]
use crate::providers::GeminiProvider;
#[cfg(feature = "grok")]
use crate::providers::GrokProvider;
#[cfg(feature = "kaggle")]
use crate::providers::KaggleProvider;
#[cfg(feature = "notebooklm")]
use crate::providers::NotebookLmProvider;
#[cfg(feature = "perplexity")]
use crate::providers::PerplexityProvider;
/// Input for a single prompt sent to a provider.
///
/// The derived `Default` yields an empty message with no context, no
/// conversation id, no attachments and no metadata.
#[derive(Debug, Clone, Default)]
pub struct PromptRequest {
/// Message text; the only field `PromptRequest::new` sets explicitly.
pub message: String,
/// Optional context string, set by `with_context`.
pub context: Option<String>,
/// When `Some`, `WebPuppet::prompt` continues that conversation instead of
/// sending a fresh prompt.
pub conversation_id: Option<String>,
/// Attachments, appended in call order by `with_attachment`.
pub attachments: Vec<Attachment>,
/// Free-form string key/value pairs.
/// NOTE(review): how providers interpret these is not visible here — confirm.
pub metadata: HashMap<String, String>,
}
impl PromptRequest {
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
..Default::default()
}
}
pub fn with_context(mut self, context: impl Into<String>) -> Self {
self.context = Some(context.into());
self
}
pub fn with_conversation(mut self, id: impl Into<String>) -> Self {
self.conversation_id = Some(id.into());
self
}
pub fn with_attachment(mut self, attachment: Attachment) -> Self {
self.attachments.push(attachment);
self
}
}
/// A named blob of bytes that can be attached to a `PromptRequest`.
#[derive(Debug, Clone)]
pub struct Attachment {
/// Attachment name.
/// NOTE(review): presumably a file name shown to the provider — confirm.
pub name: String,
/// MIME type as a string; it is not validated anywhere in this definition.
pub mime_type: String,
/// Raw attachment bytes.
pub data: Vec<u8>,
}
/// Result of a prompt, as returned by a provider implementation.
#[derive(Debug, Clone)]
pub struct PromptResponse {
/// Response text. `WebPuppet::prompt_screened` overwrites this with the
/// screener's sanitized text before returning.
pub text: String,
/// Provider associated with this response.
/// NOTE(review): presumably the provider that produced it — confirm.
pub provider: Provider,
/// Conversation identifier, if the provider reported one.
pub conversation_id: Option<String>,
/// UTC timestamp attached to the response.
/// NOTE(review): presumably the time the response was received — confirm.
pub timestamp: DateTime<Utc>,
/// Token count, when available.
/// NOTE(review): which tokens are counted is provider-defined — confirm.
pub tokens_used: Option<u32>,
/// Free-form string key/value pairs attached to the response.
pub metadata: HashMap<String, String>,
}
/// Facade that routes prompts to the registered provider implementations.
///
/// Constructed through `WebPuppet::builder()` / `WebPuppet::new()`.
pub struct WebPuppet {
/// Configuration handed to `Session::new` when a session is created.
config: Config,
/// Credential store, shared with every session created by `get_session`.
credentials: Arc<CredentialStore>,
/// Sessions created so far, keyed by provider; filled by `get_session`
/// and emptied by `close`.
sessions: Arc<RwLock<HashMap<Provider, Session>>>,
/// Provider implementations registered at build time.
providers: HashMap<Provider, Arc<dyn ProviderTrait>>,
/// Awaited at the start of every `prompt` call.
rate_limiter: Arc<RateLimiter>,
/// Screens response text in `prompt_screened`.
screener: Arc<ContentScreener>,
}
impl WebPuppet {
pub fn builder() -> WebPuppetBuilder {
WebPuppetBuilder::default()
}
pub async fn new() -> Result<Self> {
Self::builder().build().await
}
pub fn providers(&self) -> Vec<Provider> {
self.providers.keys().copied().collect()
}
pub fn provider_capabilities(
&self,
provider: Provider,
) -> Option<crate::providers::ProviderCapabilities> {
self.providers.get(&provider).map(|p| p.capabilities())
}
pub fn has_provider(&self, provider: Provider) -> bool {
self.providers.contains_key(&provider)
}
pub async fn get_session(&self, provider: Provider) -> Result<Session> {
let sessions = self.sessions.read().await;
if let Some(session) = sessions.get(&provider) {
return Ok(session.clone());
}
drop(sessions);
let session = Session::new(&self.config, provider, self.credentials.clone()).await?;
let mut sessions = self.sessions.write().await;
sessions.insert(provider, session.clone());
Ok(session)
}
pub async fn authenticate(&self, provider: Provider) -> Result<()> {
let provider_impl = self
.providers
.get(&provider)
.ok_or_else(|| Error::UnsupportedProvider(provider.to_string()))?;
let mut session = self.get_session(provider).await?;
if !provider_impl.is_authenticated(&session).await? {
provider_impl.authenticate(&mut session).await?;
}
Ok(())
}
pub async fn prompt(
&self,
provider: Provider,
request: PromptRequest,
) -> Result<PromptResponse> {
let provider_impl = self
.providers
.get(&provider)
.ok_or_else(|| Error::UnsupportedProvider(provider.to_string()))?;
self.rate_limiter.wait(provider).await;
let session = self.get_session(provider).await?;
if !provider_impl.is_authenticated(&session).await? {
return Err(Error::SessionExpired(provider.to_string()));
}
if let Some(delay) = provider_impl.check_rate_limit(&session).await? {
tracing::warn!("Rate limited by {}, waiting {:?}", provider, delay);
tokio::time::sleep(delay).await;
}
let response = if let Some(ref conv_id) = request.conversation_id {
provider_impl
.continue_conversation(&session, conv_id, &request)
.await?
} else {
provider_impl.send_prompt(&session, &request).await?
};
Ok(response)
}
pub async fn prompt_screened(
&self,
provider: Provider,
request: PromptRequest,
) -> Result<(PromptResponse, ScreeningResult)> {
let mut response = self.prompt(provider, request).await?;
let screening = self.screener.screen(&response.text);
if !screening.passed {
tracing::warn!(
"Response from {} flagged with risk score {:.2}: {:?}",
provider,
screening.risk_score,
screening
.issues
.iter()
.map(|i| format!("{:?}", i))
.collect::<Vec<_>>()
);
}
response.text = screening.sanitized.clone();
Ok((response, screening))
}
pub async fn prompt_any_screened(
&self,
request: PromptRequest,
) -> Result<(PromptResponse, ScreeningResult)> {
let providers = self.providers();
if providers.is_empty() {
return Err(Error::Config("No providers configured".into()));
}
let mut last_error = None;
for provider in providers {
match self.prompt_screened(provider, request.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
tracing::warn!("Provider {} failed: {}", provider, e);
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| Error::Config("All providers failed".into())))
}
pub fn screener(&self) -> &ContentScreener {
&self.screener
}
pub async fn prompt_any(&self, request: PromptRequest) -> Result<PromptResponse> {
let providers = self.providers();
if providers.is_empty() {
return Err(Error::Config("No providers configured".into()));
}
let mut last_error = None;
for provider in providers {
match self.prompt(provider, request.clone()).await {
Ok(response) => return Ok(response),
Err(e) => {
tracing::warn!("Provider {} failed: {}", provider, e);
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| Error::Config("All providers failed".into())))
}
pub async fn new_conversation(&self, provider: Provider) -> Result<String> {
let provider_impl = self
.providers
.get(&provider)
.ok_or_else(|| Error::UnsupportedProvider(provider.to_string()))?;
let session = self.get_session(provider).await?;
provider_impl.new_conversation(&session).await
}
pub async fn close(&self) -> Result<()> {
let mut sessions = self.sessions.write().await;
for (_, session) in sessions.drain() {
session.close().await.ok();
}
Ok(())
}
}
/// Builder for [`WebPuppet`].
///
/// All settings are optional; see [`WebPuppetBuilder::build`] for the values
/// used when a setting is left untouched.
#[derive(Default)]
pub struct WebPuppetBuilder {
    /// Explicit configuration; `Config::default()` is used when absent.
    config: Option<Config>,
    /// Explicit screening configuration; the screener's default is used when
    /// absent.
    screening_config: Option<ScreeningConfig>,
    /// Providers selected so far; empty means "all providers".
    providers: Vec<Provider>,
    /// Headless override; `None` until [`WebPuppetBuilder::headless`] is
    /// called, so that an unset value can be told apart from `false`.
    headless: Option<bool>,
}

impl WebPuppetBuilder {
    /// Uses `config` instead of `Config::default()`.
    pub fn with_config(mut self, config: Config) -> Self {
        self.config = Some(config);
        self
    }

    /// Adds `provider` to the selection; adding the same provider twice has no
    /// further effect.
    pub fn with_provider(mut self, provider: Provider) -> Self {
        if !self.providers.contains(&provider) {
            self.providers.push(provider);
        }
        self
    }

    /// Replaces the current selection with `Provider::all()`.
    pub fn with_all_providers(mut self) -> Self {
        self.providers = Provider::all();
        self
    }

    /// Forces the browser's headless setting to `headless`, overriding
    /// whatever the configuration says.
    pub fn headless(mut self, headless: bool) -> Self {
        self.headless = Some(headless);
        self
    }

    /// Uses `config` for the content screener instead of its default.
    pub fn with_screening_config(mut self, config: ScreeningConfig) -> Self {
        self.screening_config = Some(config);
        self
    }

    /// Assembles the [`WebPuppet`].
    ///
    /// * Headless mode: an explicit [`WebPuppetBuilder::headless`] call always
    ///   wins. Otherwise a configuration passed to
    ///   [`WebPuppetBuilder::with_config`] keeps its own
    ///   `browser.headless` value, and with no configuration at all the
    ///   browser is not headless.
    /// * Providers: an empty selection means `Provider::all()`. A provider
    ///   whose Cargo feature is not compiled in is skipped with a debug log.
    ///
    /// # Errors
    /// Propagates any error from `CredentialStore::new`.
    pub async fn build(self) -> Result<WebPuppet> {
        let config_supplied = self.config.is_some();
        let mut config = self.config.unwrap_or_default();
        match self.headless {
            Some(headless) => config.browser.headless = headless,
            // No configuration and no override: keep the builder's
            // long-standing default of a visible browser.
            None if !config_supplied => config.browser.headless = false,
            // A caller-supplied configuration must not be clobbered by a
            // setting the caller never touched.
            None => {}
        }

        let credentials = Arc::new(CredentialStore::new()?);
        let rate_limiter = Arc::new(RateLimiter::new(&config.rate_limit));
        let mut providers: HashMap<Provider, Arc<dyn ProviderTrait>> = HashMap::new();
        let enabled_providers = if self.providers.is_empty() {
            Provider::all()
        } else {
            self.providers
        };
        for provider in enabled_providers {
            match provider {
                #[cfg(feature = "grok")]
                Provider::Grok => {
                    providers.insert(provider, Arc::new(GrokProvider::new()));
                }
                #[cfg(feature = "claude")]
                Provider::Claude => {
                    providers.insert(provider, Arc::new(ClaudeProvider::new()));
                }
                #[cfg(feature = "gemini")]
                Provider::Gemini => {
                    providers.insert(provider, Arc::new(GeminiProvider::new()));
                }
                #[cfg(feature = "chatgpt")]
                Provider::ChatGpt => {
                    providers.insert(provider, Arc::new(ChatGptProvider::new()));
                }
                #[cfg(feature = "perplexity")]
                Provider::Perplexity => {
                    providers.insert(provider, Arc::new(PerplexityProvider::new()));
                }
                #[cfg(feature = "notebooklm")]
                Provider::NotebookLm => {
                    providers.insert(provider, Arc::new(NotebookLmProvider::new()));
                }
                #[cfg(feature = "kaggle")]
                Provider::Kaggle => {
                    providers.insert(provider, Arc::new(KaggleProvider::new()));
                }
                // With every provider feature enabled the arms above are
                // exhaustive and this arm can never match, hence the allow.
                #[allow(unreachable_patterns)]
                _ => {
                    tracing::debug!("Provider {:?} not enabled via features", provider);
                }
            }
        }
        let screener = Arc::new(
            self.screening_config
                .map(ContentScreener::with_config)
                .unwrap_or_default(),
        );
        Ok(WebPuppet {
            config,
            credentials,
            sessions: Arc::new(RwLock::new(HashMap::new())),
            providers,
            rate_limiter,
            screener,
        })
    }
}
/// One-shot convenience: builds a headless [`WebPuppet`] for `provider`,
/// authenticates, sends `message` as a fresh prompt, and closes the puppet.
///
/// The puppet is closed whether or not authentication and prompting
/// succeed, so a failed call does not leave its session open.
///
/// # Errors
/// * Any error from building the puppet.
/// * Any error from `authenticate` or `prompt`; this takes precedence over a
///   close error.
/// * Any error from `close` when the prompt itself succeeded.
pub async fn quick_prompt(
    provider: Provider,
    message: impl Into<String>,
) -> Result<PromptResponse> {
    let puppet = WebPuppet::builder()
        .with_provider(provider)
        .headless(true)
        .build()
        .await?;

    // Run the fallible steps without returning early, so the cleanup below
    // is reached on every path.
    let outcome = async {
        puppet.authenticate(provider).await?;
        puppet.prompt(provider, PromptRequest::new(message)).await
    }
    .await;

    let closed = puppet.close().await;
    let response = outcome?;
    closed?;
    Ok(response)
}