use core::fmt;
use serde::{Deserialize, Serialize};
use crate::{
context::ContextProjection,
domain::{
AgentError, AgentErrorKind, ContentRef, DestinationKind, OutputSchemaId, PrivacyClass,
RetryClassification, SourceKind, ToolCallId,
},
output::{ContentHash, OutputContract, OutputSchemaRef, ProviderHintPolicy, SchemaVersion},
projection::project_context_projection,
tool_records::CanonicalToolName,
};
pub trait ProviderAdapter: Send + Sync {
fn capabilities(&self) -> ProviderCapabilities;
fn project_request(
&self,
projection: &ContextProjection,
policy: &ProviderProjectionPolicy,
) -> Result<ProviderRequest, AgentError> {
project_context_projection(projection, policy)
}
fn complete(&self, request: &ProviderRequest) -> Result<ProviderResponse, AgentError>;
fn stream(&self, request: &ProviderRequest) -> Result<Vec<ProviderStreamChunk>, AgentError> {
let response = self.complete(request)?;
if response.tool_calls.is_empty() {
return Ok(vec![ProviderStreamChunk::final_text(
response.output_text.clone(),
response.stop_reason.clone(),
response.usage.clone(),
)]);
}
let mut chunks = Vec::new();
let mut chunk_index = 0;
if !response.output_text.is_empty() {
chunks.push(ProviderStreamChunk::text(
chunk_index,
response.output_text.clone(),
));
chunk_index += 1;
}
chunks.push(ProviderStreamChunk::final_tool_calls(
chunk_index,
response.tool_calls.clone(),
response.stop_reason.clone(),
response.usage.clone(),
));
Ok(chunks)
}
fn extract_usage(&self, response: &ProviderResponse) -> ProviderUsage {
response.usage.clone().unwrap_or_default()
}
}
/// Static description of what a provider adapter advertises.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderCapabilities {
    /// Identifier of the provider; the conformance case rejects an empty value.
    pub provider_ref: String,
    /// Whether the provider advertises streaming support.
    pub supports_streaming: bool,
    /// Whether the provider advertises usage reporting.
    pub supports_usage: bool,
    /// Advertised input-token limit, if any.
    pub max_input_tokens: Option<u32>,
    /// Content modalities the provider advertises.
    pub supported_modalities: Vec<ProviderModality>,
}

impl ProviderCapabilities {
    /// Capabilities for a text-only provider.
    ///
    /// This sets streaming off, usage reporting on, and no input-token limit, and
    /// advertises [`ProviderModality::Text`] as the only modality.
    pub fn text_only(provider_ref: impl Into<String>) -> Self {
        let modalities = vec![ProviderModality::Text];
        Self {
            supported_modalities: modalities,
            max_input_tokens: None,
            supports_usage: true,
            supports_streaming: false,
            provider_ref: provider_ref.into(),
        }
    }
}
/// Content modality a provider may advertise in `ProviderCapabilities::supported_modalities`.
///
/// Serialized with snake_case variant names, e.g. `"text"` or `"video"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderModality {
Text,
Image,
Audio,
Video,
}
/// Policy handed to the projection step when building a [`ProviderRequest`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderProjectionPolicy {
    /// Whether the projection step is permitted to project private metadata.
    pub allow_private_metadata_projection: bool,
    /// Identifier of this policy.
    pub projection_policy_ref: String,
}

impl ProviderProjectionPolicy {
    /// Policy named `policy_ref` that does not permit private metadata projection.
    pub fn redacted(policy_ref: impl Into<String>) -> Self {
        let projection_policy_ref = policy_ref.into();
        Self {
            projection_policy_ref,
            allow_private_metadata_projection: false,
        }
    }

    /// Policy named `policy_ref` that permits private metadata projection.
    pub fn allow_private_metadata(policy_ref: impl Into<String>) -> Self {
        let projection_policy_ref = policy_ref.into();
        Self {
            projection_policy_ref,
            allow_private_metadata_projection: true,
        }
    }
}
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderRequest {
pub schema_version: u16,
pub projection_policy_ref: String,
pub messages: Vec<ProviderMessage>,
pub projection_item_count: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub structured_output_hint: Option<ProviderStructuredOutputHint>,
}
impl ProviderRequest {
pub const SCHEMA_VERSION: u16 = 1;
pub fn with_structured_output_hint(mut self, contract: &OutputContract) -> Self {
self.structured_output_hint = Some(ProviderStructuredOutputHint::from(contract));
self
}
}
impl fmt::Debug for ProviderRequest {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("ProviderRequest")
.field("schema_version", &self.schema_version)
.field("projection_policy_ref", &self.projection_policy_ref)
.field("message_count", &self.messages.len())
.field("messages", &"<redacted>")
.field("projection_item_count", &self.projection_item_count)
.field("structured_output_hint", &self.structured_output_hint)
.finish()
}
}
/// Structured-output hint derived from an [`OutputContract`] and attached to a request.
///
/// `Debug` output reports only whether a redacted schema is present, never its contents.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderStructuredOutputHint {
    /// Copied from the contract's `schema_id`.
    pub schema_id: OutputSchemaId,
    /// Copied from the contract's `schema_version`.
    pub schema_version: SchemaVersion,
    /// The contract's `schema_fingerprint()`.
    pub schema_fingerprint: ContentHash,
    /// Copied from the contract's projection hint.
    pub provider_hint_policy: ProviderHintPolicy,
    /// Copied from the contract's projection hint.
    pub include_schema_ref: bool,
    /// Redacted schema, set only for `OutputSchemaRef::InlineJson` contracts.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub redacted_schema: Option<serde_json::Value>,
}

impl From<&OutputContract> for ProviderStructuredOutputHint {
    fn from(contract: &OutputContract) -> Self {
        let projection_hint = &contract.projection_hint;
        // Only the `InlineJson` schema variant contributes a redacted schema.
        let redacted_schema =
            if let OutputSchemaRef::InlineJson { redacted_schema, .. } = &contract.schema {
                Some(redacted_schema.clone())
            } else {
                None
            };
        Self {
            schema_id: contract.schema_id.clone(),
            schema_version: contract.schema_version,
            schema_fingerprint: contract.schema_fingerprint(),
            provider_hint_policy: projection_hint.provider_hint_policy.clone(),
            include_schema_ref: projection_hint.include_schema_ref,
            redacted_schema,
        }
    }
}

impl fmt::Debug for ProviderStructuredOutputHint {
    // The schema body is replaced by a presence flag.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_redacted_schema = self.redacted_schema.is_some();
        f.debug_struct("ProviderStructuredOutputHint")
            .field("schema_id", &self.schema_id)
            .field("schema_version", &self.schema_version)
            .field("schema_fingerprint", &self.schema_fingerprint)
            .field("provider_hint_policy", &self.provider_hint_policy)
            .field("include_schema_ref", &self.include_schema_ref)
            .field("redacted_schema_present", &has_redacted_schema)
            .finish()
    }
}
/// One message inside a [`ProviderRequest`].
///
/// `Debug` output hides `content` and reports only its length in `char`s.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderMessage {
    pub role: ProviderMessageRole,
    pub content: String,
    pub privacy: PrivacyClass,
    /// Optional projected metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub projected_metadata: Option<ProviderProjectedMetadata>,
}

impl fmt::Debug for ProviderMessage {
    // NOTE(review): `projected_metadata` is printed through its derived `Debug`, so its
    // identifiers are not redacted here — confirm that is intended.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ProviderMessage {
            role,
            content,
            privacy,
            projected_metadata,
        } = self;
        let content_chars = content.chars().count();
        f.debug_struct("ProviderMessage")
            .field("role", role)
            .field("content", &"<redacted>")
            .field("content_chars", &content_chars)
            .field("privacy", privacy)
            .field("projected_metadata", projected_metadata)
            .finish()
    }
}
/// Role of a [`ProviderMessage`] within a provider request.
///
/// Serialized with snake_case variant names, e.g. `"system"` or `"tool"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderMessageRole {
System,
Developer,
User,
Assistant,
Tool,
Context,
}
/// Source, destination and subject identifiers attached to a [`ProviderMessage`]
/// as `projected_metadata`.
///
/// NOTE(review): presumably populated only when the projection policy permits private
/// metadata projection — confirm in `project_context_projection`.
/// This type derives `Debug`, so every identifier appears verbatim in debug output.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderProjectedMetadata {
pub source_kind: SourceKind,
pub source_id: String,
pub destination_kind: DestinationKind,
pub destination_id: String,
pub subject_kind: String,
pub subject_id: String,
}
/// Provider-neutral result of a completion.
///
/// `Debug` output hides `output_text` and reports only its length in `char`s.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderResponse {
    /// Wire schema version; see [`ProviderResponse::SCHEMA_VERSION`].
    pub schema_version: u16,
    pub output_text: String,
    pub stop_reason: ProviderStopReason,
    /// Requested tool calls; omitted from serialized output when empty and
    /// defaulted to empty when missing on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ProviderToolCall>,
    /// Token usage; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<ProviderUsage>,
}

impl ProviderResponse {
    /// Current wire schema version for [`ProviderResponse`].
    pub const SCHEMA_VERSION: u16 = 1;

    /// Plain-text response that stops with [`ProviderStopReason::EndTurn`] and has no
    /// tool calls and no usage.
    pub fn text(output_text: impl Into<String>) -> Self {
        let output_text: String = output_text.into();
        Self {
            schema_version: Self::SCHEMA_VERSION,
            stop_reason: ProviderStopReason::EndTurn,
            output_text,
            tool_calls: Vec::new(),
            usage: None,
        }
    }

    /// Tool-use response with empty text that stops with [`ProviderStopReason::ToolUse`]
    /// and has no usage.
    pub fn tool_use(tool_calls: impl IntoIterator<Item = ProviderToolCall>) -> Self {
        let tool_calls: Vec<ProviderToolCall> = tool_calls.into_iter().collect();
        Self {
            schema_version: Self::SCHEMA_VERSION,
            stop_reason: ProviderStopReason::ToolUse,
            output_text: String::new(),
            tool_calls,
            usage: None,
        }
    }

    /// Returns the response with `usage` set, replacing any previous value.
    pub fn with_usage(self, usage: ProviderUsage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }
}

impl fmt::Debug for ProviderResponse {
    // The output text is replaced by a placeholder plus its `char` count.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text_chars = self.output_text.chars().count();
        let mut out = f.debug_struct("ProviderResponse");
        out.field("schema_version", &self.schema_version)
            .field("output_text", &"<redacted>")
            .field("output_text_chars", &text_chars)
            .field("stop_reason", &self.stop_reason)
            .field("tool_calls", &self.tool_calls)
            .field("usage", &self.usage);
        out.finish()
    }
}
/// Why the provider stopped producing output.
///
/// Serialized with snake_case variant names, e.g. `"end_turn"` or `"tool_use"`.
/// `ProviderResponse::text` uses `EndTurn`; `ProviderResponse::tool_use` uses `ToolUse`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderStopReason {
EndTurn,
MaxTokens,
ToolUse,
Cancelled,
ProviderError,
Unknown,
}
/// One chunk of a streamed provider response.
///
/// `Debug` output prints `delta` through [`ProviderStreamDelta`]'s own `Debug`, which
/// redacts text payloads.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderStreamChunk {
    /// Wire schema version; see [`ProviderStreamChunk::SCHEMA_VERSION`].
    pub schema_version: u16,
    /// Position of this chunk within the stream.
    pub chunk_index: u32,
    /// Payload carried by this chunk.
    pub delta: ProviderStreamDelta,
    /// Set on the chunk that ends the stream (the `final_*` constructors).
    pub is_terminal: bool,
    /// Token usage; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<ProviderUsage>,
}

impl ProviderStreamChunk {
    /// Current wire schema version for [`ProviderStreamChunk`].
    pub const SCHEMA_VERSION: u16 = 1;

    /// Non-terminal text delta at `chunk_index`, with no stop reason and no usage.
    pub fn text(chunk_index: u32, text: impl Into<String>) -> Self {
        Self {
            schema_version: Self::SCHEMA_VERSION,
            chunk_index,
            delta: ProviderStreamDelta::Text {
                text: text.into(),
                stop_reason: None,
            },
            is_terminal: false,
            usage: None,
        }
    }

    /// Terminal text chunk at index `0`, carrying `stop_reason` and `usage`.
    ///
    /// Intended for single-chunk streams. When earlier chunks precede the final one, use
    /// [`ProviderStreamChunk::final_text_at`] so chunk indices stay distinct.
    pub fn final_text(
        text: impl Into<String>,
        stop_reason: ProviderStopReason,
        usage: Option<ProviderUsage>,
    ) -> Self {
        Self::final_text_at(0, text, stop_reason, usage)
    }

    /// Terminal text chunk at an explicit `chunk_index`, carrying `stop_reason` and `usage`.
    ///
    /// Mirrors [`ProviderStreamChunk::final_tool_calls`], which already takes an index.
    pub fn final_text_at(
        chunk_index: u32,
        text: impl Into<String>,
        stop_reason: ProviderStopReason,
        usage: Option<ProviderUsage>,
    ) -> Self {
        Self {
            schema_version: Self::SCHEMA_VERSION,
            chunk_index,
            delta: ProviderStreamDelta::Text {
                text: text.into(),
                stop_reason: Some(stop_reason),
            },
            is_terminal: true,
            usage,
        }
    }

    /// Non-terminal tool-calls delta at `chunk_index`, with no stop reason and no usage.
    pub fn tool_calls(
        chunk_index: u32,
        tool_calls: impl IntoIterator<Item = ProviderToolCall>,
    ) -> Self {
        Self {
            schema_version: Self::SCHEMA_VERSION,
            chunk_index,
            delta: ProviderStreamDelta::ToolCalls {
                tool_calls: tool_calls.into_iter().collect(),
                stop_reason: None,
            },
            is_terminal: false,
            usage: None,
        }
    }

    /// Terminal tool-calls chunk at `chunk_index`, carrying `stop_reason` and `usage`.
    pub fn final_tool_calls(
        chunk_index: u32,
        tool_calls: impl IntoIterator<Item = ProviderToolCall>,
        stop_reason: ProviderStopReason,
        usage: Option<ProviderUsage>,
    ) -> Self {
        Self {
            schema_version: Self::SCHEMA_VERSION,
            chunk_index,
            delta: ProviderStreamDelta::ToolCalls {
                tool_calls: tool_calls.into_iter().collect(),
                stop_reason: Some(stop_reason),
            },
            is_terminal: true,
            usage,
        }
    }
}

impl fmt::Debug for ProviderStreamChunk {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter
            .debug_struct("ProviderStreamChunk")
            .field("schema_version", &self.schema_version)
            .field("chunk_index", &self.chunk_index)
            .field("delta", &self.delta)
            .field("is_terminal", &self.is_terminal)
            .field("usage", &self.usage)
            .finish()
    }
}
/// Payload of a [`ProviderStreamChunk`].
///
/// Serialized as an internally tagged object whose `"kind"` field holds the snake_case
/// variant name (e.g. `"tool_calls"`). `Debug` output replaces text with its `char` count.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ProviderStreamDelta {
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        stop_reason: Option<ProviderStopReason>,
    },
    Usage {
        usage: ProviderUsage,
    },
    ToolCalls {
        tool_calls: Vec<ProviderToolCall>,
        #[serde(skip_serializing_if = "Option::is_none")]
        stop_reason: Option<ProviderStopReason>,
    },
    Error {
        redacted_summary: String,
    },
}

impl fmt::Debug for ProviderStreamDelta {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProviderStreamDelta::Error { redacted_summary } => f
                .debug_struct("Error")
                .field("redacted_summary", redacted_summary)
                .finish(),
            ProviderStreamDelta::Usage { usage } => {
                f.debug_struct("Usage").field("usage", usage).finish()
            }
            ProviderStreamDelta::ToolCalls {
                tool_calls,
                stop_reason,
            } => f
                .debug_struct("ToolCalls")
                .field("tool_calls", tool_calls)
                .field("stop_reason", stop_reason)
                .finish(),
            ProviderStreamDelta::Text { text, stop_reason } => {
                // Never print the text itself; its length is enough for diagnostics.
                let text_chars = text.chars().count();
                f.debug_struct("Text")
                    .field("text", &"<redacted>")
                    .field("text_chars", &text_chars)
                    .field("stop_reason", stop_reason)
                    .finish()
            }
        }
    }
}
/// A tool invocation requested by the provider.
///
/// `Debug` output omits `redacted_args_summary` and shows only its length in `char`s.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderToolCall {
    pub tool_call_id: ToolCallId,
    pub canonical_tool_name: CanonicalToolName,
    /// References to argument content; omitted from serialized output when empty and
    /// defaulted to empty when missing on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub requested_args_refs: Vec<ContentRef>,
    pub redacted_args_summary: String,
}

impl ProviderToolCall {
    /// Creates a tool call with no argument references.
    pub fn new(
        tool_call_id: ToolCallId,
        canonical_tool_name: CanonicalToolName,
        redacted_args_summary: impl Into<String>,
    ) -> Self {
        let redacted_args_summary = redacted_args_summary.into();
        Self {
            tool_call_id,
            canonical_tool_name,
            redacted_args_summary,
            requested_args_refs: Vec::new(),
        }
    }

    /// Returns the call with `args_ref` appended to its argument references.
    pub fn with_args_ref(self, args_ref: ContentRef) -> Self {
        let mut call = self;
        call.requested_args_refs.push(args_ref);
        call
    }
}

impl fmt::Debug for ProviderToolCall {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ProviderToolCall {
            tool_call_id,
            canonical_tool_name,
            requested_args_refs,
            redacted_args_summary,
        } = self;
        let summary_chars = redacted_args_summary.chars().count();
        f.debug_struct("ProviderToolCall")
            .field("tool_call_id", tool_call_id)
            .field("canonical_tool_name", canonical_tool_name)
            .field("requested_args_refs", requested_args_refs)
            .field("redacted_args_summary_chars", &summary_chars)
            .finish()
    }
}
/// Token counts reported by a provider; each count is optional.
///
/// `Default` yields all-`None`, which the default `ProviderAdapter::extract_usage`
/// returns when a response carries no usage.
/// NOTE(review): `total_tokens` is not derived from or checked against the other two
/// counts anywhere in this file.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct ProviderUsage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
pub total_tokens: Option<u32>,
}
/// Reusable conformance check for [`ProviderAdapter`] implementations.
///
/// Feeds one context projection through an adapter under a projection policy and
/// verifies the basic adapter contract.
#[derive(Clone, Debug)]
pub struct ProviderConformanceCase {
    /// Projection passed to [`ProviderAdapter::project_request`].
    pub projection: ContextProjection,
    /// Policy passed to [`ProviderAdapter::project_request`].
    pub policy: ProviderProjectionPolicy,
}

impl ProviderConformanceCase {
    /// Creates a case for `projection` using the redacted policy
    /// `"policy.provider.redacted"`.
    pub fn new(projection: ContextProjection) -> Self {
        Self {
            projection,
            policy: ProviderProjectionPolicy::redacted("policy.provider.redacted"),
        }
    }

    /// Exercises `adapter` end to end and returns the usage it extracts.
    ///
    /// Checks, in order:
    /// 1. the capabilities name a provider ref that is neither empty nor whitespace-only;
    /// 2. the projected request's `projection_item_count` equals the number of items in
    ///    the projection;
    /// 3. a completion of that request succeeds.
    ///
    /// `A` may be unsized, so `&dyn ProviderAdapter` can be checked as well.
    ///
    /// # Errors
    /// - `ProviderFailure` / `HostConfigurationNeeded` when the provider ref is blank;
    /// - `ProjectionFailure` / `RepairNeeded` when the item counts differ;
    /// - any error returned by `project_request` or `complete`.
    pub fn assert_adapter<A: ProviderAdapter + ?Sized>(
        &self,
        adapter: &A,
    ) -> Result<ProviderUsage, AgentError> {
        let capabilities = adapter.capabilities();
        // A whitespace-only ref names nothing, so reject it just like an empty one.
        if capabilities.provider_ref.trim().is_empty() {
            return Err(AgentError::new(
                AgentErrorKind::ProviderFailure,
                RetryClassification::HostConfigurationNeeded,
                "provider capabilities must name a provider ref",
            ));
        }
        let request = adapter.project_request(&self.projection, &self.policy)?;
        if request.projection_item_count != self.projection.items.len() {
            return Err(AgentError::new(
                AgentErrorKind::ProjectionFailure,
                RetryClassification::RepairNeeded,
                "provider request item count must match the context projection",
            ));
        }
        let response = adapter.complete(&request)?;
        Ok(adapter.extract_usage(&response))
    }
}