use serde::{Deserialize, Serialize};
use crate::llm::HostedCapabilities;
use super::SessionInitError;
/// How a session wants the web-search capability to be provided.
///
/// Serialized with `snake_case` variant names (`"delegate"` / `"disabled"`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum WebSearchCapabilityMode {
    /// Use the provider's hosted web search. Capability resolution fails
    /// if the provider does not report `web_search` as a hosted capability.
    Delegate,
    /// Web search is not enabled for the session. This is the default mode.
    #[default]
    Disabled,
}
/// Web-search settings for a single session.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub struct WebSearchCapabilityConfig {
    /// Requested mode; `Default` yields [`WebSearchCapabilityMode::Disabled`].
    pub mode: WebSearchCapabilityMode,
}

impl WebSearchCapabilityConfig {
    /// Creates a web-search config using the given `mode`.
    ///
    /// Because the struct is `#[non_exhaustive]`, code outside this crate
    /// cannot use a struct literal and must go through this constructor
    /// (or `Default`).
    #[must_use]
    pub const fn new(mode: WebSearchCapabilityMode) -> Self {
        WebSearchCapabilityConfig { mode }
    }
}
/// Capability settings requested for a session, before they are checked
/// against a provider by [`ResolvedSessionCapabilities::resolve`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub struct SessionCapabilitiesConfig {
    /// Web-search settings; `Default` leaves web search disabled.
    pub web_search: WebSearchCapabilityConfig,
}

impl SessionCapabilitiesConfig {
    /// Creates a session capability config from its web-search settings.
    ///
    /// Because the struct is `#[non_exhaustive]`, code outside this crate
    /// cannot use a struct literal and must go through this constructor
    /// (or `Default`).
    #[must_use]
    pub const fn with_web_search(web_search: WebSearchCapabilityConfig) -> Self {
        SessionCapabilitiesConfig { web_search }
    }
}
/// Capabilities a session ends up with after its configuration has been
/// checked against what the provider offers.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ResolvedSessionCapabilities {
    /// Hosted capabilities enabled for the session.
    pub hosted: HostedCapabilities,
}

impl ResolvedSessionCapabilities {
    /// Checks `config` against the hosted capabilities `provider_hosted`
    /// reported by the provider identified by `provider_id`.
    ///
    /// The result starts from `HostedCapabilities::default()`. `web_search`
    /// is switched on only when the config requests
    /// [`WebSearchCapabilityMode::Delegate`] and the provider supports it.
    /// Provider capabilities the config does not request are not copied.
    ///
    /// # Errors
    ///
    /// Returns `SessionInitError::CapabilityUnsatisfied` with
    /// `capability: "web_search"` and `provider` set to `provider_id` when
    /// the config delegates web search but `provider_hosted.web_search` is
    /// `false`.
    pub fn resolve(
        config: SessionCapabilitiesConfig,
        provider_hosted: HostedCapabilities,
        provider_id: &str,
    ) -> Result<Self, SessionInitError> {
        let mut hosted = HostedCapabilities::default();

        // No wildcard arm: adding a mode must force a decision here.
        match config.web_search.mode {
            WebSearchCapabilityMode::Disabled => {}
            WebSearchCapabilityMode::Delegate if provider_hosted.web_search => {
                hosted.web_search = true;
            }
            WebSearchCapabilityMode::Delegate => {
                return Err(SessionInitError::CapabilityUnsatisfied {
                    capability: "web_search",
                    provider: provider_id.to_owned(),
                });
            }
        }

        Ok(ResolvedSessionCapabilities { hosted })
    }
}
#[cfg(test)]
mod tests;