use std::{fmt, sync::Arc};
use agent_sdk_core::{
AgentError, AgentErrorKind, ProviderAdapter, ProviderCapabilities, ProviderMessageRole,
ProviderRequest, ProviderResponse, ProviderStopReason, ProviderToolCall, ProviderUsage,
RetryClassification, ToolCallId, domain::ContentRef as ContentRefId,
tool_records::CanonicalToolName,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Static settings for an OpenAI-compatible Responses adapter.
///
/// Start from `OpenAiResponsesConfig::new` and adjust with the chained setters.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesConfig {
/// Provider identifier passed to `ProviderCapabilities::text_only` when the adapter
/// reports its capabilities.
pub provider_ref: String,
/// Model name copied into the `model` field of every outgoing request.
pub model: String,
/// Endpoint identifier copied into every outgoing request; `new` sets it to
/// `"endpoint.host_configured.openai_compatible"`.
pub endpoint_ref: String,
/// Reported as `ProviderCapabilities::supports_streaming`. Within this file the flag is
/// only advertised; `complete` always makes a single transport call.
pub supports_streaming: bool,
/// Reported as `ProviderCapabilities::max_input_tokens`; `None` is passed through unchanged.
pub max_input_tokens: Option<u32>,
}
impl OpenAiResponsesConfig {
    /// Creates a config for `model` served by `provider_ref`.
    ///
    /// Defaults: `endpoint_ref` is `"endpoint.host_configured.openai_compatible"`,
    /// streaming is not advertised, and no input-token limit is set.
    #[must_use]
    pub fn new(provider_ref: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            provider_ref: provider_ref.into(),
            model: model.into(),
            endpoint_ref: "endpoint.host_configured.openai_compatible".to_string(),
            supports_streaming: false,
            max_input_tokens: None,
        }
    }

    /// Replaces the endpoint identifier copied into every outgoing request.
    // The setters consume `self`; ignoring the return value silently drops the change,
    // so the result is marked `#[must_use]`.
    #[must_use = "the builder returns the updated config; dropping it discards the change"]
    pub fn endpoint_ref(mut self, endpoint_ref: impl Into<String>) -> Self {
        self.endpoint_ref = endpoint_ref.into();
        self
    }

    /// Sets whether streaming support is advertised in the adapter's capabilities.
    #[must_use = "the builder returns the updated config; dropping it discards the change"]
    pub fn supports_streaming(mut self, supports_streaming: bool) -> Self {
        self.supports_streaming = supports_streaming;
        self
    }

    /// Sets the input-token limit advertised in the adapter's capabilities.
    #[must_use = "the builder returns the updated config; dropping it discards the change"]
    pub fn max_input_tokens(mut self, max_input_tokens: u32) -> Self {
        self.max_input_tokens = Some(max_input_tokens);
        self
    }
}
/// Performs the actual exchange with an OpenAI-compatible Responses endpoint.
///
/// The adapter stores the transport as `Arc<dyn OpenAiResponsesTransport>`, hands it one
/// `OpenAiResponsesRequest` per `ProviderAdapter::complete` call, and propagates any
/// `Err` it returns unchanged.
pub trait OpenAiResponsesTransport: Send + Sync {
/// Sends `request` and returns the decoded `OpenAiResponsesResponse`.
///
/// NOTE(review): resolving `request.endpoint_ref` to a concrete URL and classifying
/// network/HTTP failures are left to the implementation — not visible in this file.
fn complete(
&self,
request: OpenAiResponsesRequest,
) -> Result<OpenAiResponsesResponse, AgentError>;
}
/// Optional hook that persists raw tool-call arguments and hands back a content ref.
///
/// Installed with `OpenAiCompatibleResponsesAdapter::with_argument_sink`. When installed,
/// it is called once for each `function_call` output item that carries `arguments`, in
/// output order. Without a sink, argument strings are not forwarded anywhere.
pub trait OpenAiToolArgumentSink: Send + Sync {
/// Stores `raw_arguments` (the provider's argument string, unparsed) for `call_id`.
///
/// Return `Ok(Some(content_ref))` to attach the ref to the resulting `ProviderToolCall`,
/// `Ok(None)` to leave the call without one, or `Err` to abort mapping of the whole
/// response.
fn store_tool_arguments(
&self,
call_id: &str,
canonical_tool_name: &CanonicalToolName,
raw_arguments: &str,
) -> Result<Option<ContentRefId>, AgentError>;
}
#[derive(Clone)]
pub struct OpenAiCompatibleResponsesAdapter {
config: OpenAiResponsesConfig,
transport: Arc<dyn OpenAiResponsesTransport>,
argument_sink: Option<Arc<dyn OpenAiToolArgumentSink>>,
}
impl OpenAiCompatibleResponsesAdapter {
pub fn new(
config: OpenAiResponsesConfig,
transport: Arc<dyn OpenAiResponsesTransport>,
) -> Self {
Self {
config,
transport,
argument_sink: None,
}
}
pub fn with_argument_sink(mut self, sink: Arc<dyn OpenAiToolArgumentSink>) -> Self {
self.argument_sink = Some(sink);
self
}
pub fn config(&self) -> &OpenAiResponsesConfig {
&self.config
}
fn map_response(
&self,
response: OpenAiResponsesResponse,
) -> Result<ProviderResponse, AgentError> {
let usage = response.usage.clone().map(ProviderUsage::from);
let tool_calls = self.tool_calls_from_response(&response)?;
if !tool_calls.is_empty() {
let mut mapped = ProviderResponse::tool_use(tool_calls);
mapped.usage = usage;
return Ok(mapped);
}
Ok(ProviderResponse {
schema_version: ProviderResponse::SCHEMA_VERSION,
output_text: response.output_text(),
stop_reason: response.stop_reason_without_tools(),
tool_calls: Vec::new(),
usage,
})
}
fn tool_calls_from_response(
&self,
response: &OpenAiResponsesResponse,
) -> Result<Vec<ProviderToolCall>, AgentError> {
let mut calls = Vec::new();
for item in &response.output {
if item.kind != "function_call" {
continue;
}
let call_id = item.call_id.as_deref().ok_or_else(|| {
provider_failure("OpenAI-compatible function_call item missing call_id")
})?;
let name = item.name.as_deref().ok_or_else(|| {
provider_failure("OpenAI-compatible function_call item missing name")
})?;
let canonical_tool_name = CanonicalToolName::new(name);
let mut call = ProviderToolCall::new(
ToolCallId::new(call_id),
canonical_tool_name.clone(),
format!("provider requested tool {name} with arguments stored as content refs"),
);
if let (Some(sink), Some(raw_arguments)) =
(self.argument_sink.as_ref(), item.arguments.as_deref())
{
if let Some(args_ref) =
sink.store_tool_arguments(call_id, &canonical_tool_name, raw_arguments)?
{
call = call.with_args_ref(args_ref);
}
}
calls.push(call);
}
Ok(calls)
}
}
impl ProviderAdapter for OpenAiCompatibleResponsesAdapter {
    /// Starts from `ProviderCapabilities::text_only` for the configured provider. Then it
    /// overlays the config's input-token limit and streaming flag.
    fn capabilities(&self) -> ProviderCapabilities {
        let config = &self.config;
        let mut reported = ProviderCapabilities::text_only(config.provider_ref.clone());
        reported.max_input_tokens = config.max_input_tokens;
        reported.supports_streaming = config.supports_streaming;
        reported
    }

    /// Translates `request` to the wire format, sends it through the transport, and maps
    /// the reply back. Transport errors are returned unchanged.
    fn complete(&self, request: &ProviderRequest) -> Result<ProviderResponse, AgentError> {
        self.transport
            .complete(OpenAiResponsesRequest::from_provider_request(
                &self.config,
                request,
            ))
            .and_then(|response| self.map_response(response))
    }
}
/// Request body handed to `OpenAiResponsesTransport::complete`.
///
/// `Debug` is implemented by hand so message contents are never printed.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesRequest {
/// Model name from `OpenAiResponsesConfig::model`.
pub model: String,
/// Conversation messages, in the same order as `ProviderRequest::messages`.
pub input: Vec<OpenAiInputMessage>,
/// Structured-output hint; omitted from the serialized body when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<OpenAiTextFormatHint>,
/// Endpoint identifier from `OpenAiResponsesConfig::endpoint_ref`.
///
/// NOTE(review): this is serialized into the body alongside the API fields. Presumably
/// the transport consumes or strips it before sending. Please confirm, since it is not
/// a field of the upstream OpenAI Responses request.
pub endpoint_ref: String,
}
impl OpenAiResponsesRequest {
    /// Builds the wire request for `request` using the model and endpoint from `config`.
    ///
    /// Messages keep their original order. A structured-output hint, when present,
    /// becomes the `text` field.
    pub fn from_provider_request(
        config: &OpenAiResponsesConfig,
        request: &ProviderRequest,
    ) -> Self {
        let input = request
            .messages
            .iter()
            .map(OpenAiInputMessage::from_provider_message)
            .collect();
        let text = request
            .structured_output_hint
            .as_ref()
            .map(OpenAiTextFormatHint::from_provider_hint);
        Self {
            model: config.model.clone(),
            input,
            text,
            endpoint_ref: config.endpoint_ref.clone(),
        }
    }
}
impl fmt::Debug for OpenAiResponsesRequest {
    /// Prints request metadata. Message contents are redacted; only their count is shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let input_count = self.input.len();
        let mut view = formatter.debug_struct("OpenAiResponsesRequest");
        view.field("model", &self.model);
        view.field("input_count", &input_count);
        view.field("input", &"<redacted>");
        view.field("text", &self.text);
        view.field("endpoint_ref", &self.endpoint_ref);
        view.finish()
    }
}
/// One entry of the request `input` array: a wire role plus plain-text content.
///
/// `Debug` is implemented by hand and redacts `content`.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiInputMessage {
/// Wire role string. When built from a provider message it is one of `"system"`,
/// `"developer"`, `"user"`, `"assistant"` or `"tool"`.
pub role: String,
/// Message text, copied verbatim from the provider message.
pub content: String,
}
impl OpenAiInputMessage {
    /// Converts a provider message into its wire form. The role is mapped through
    /// `role_name` and the content is copied verbatim.
    fn from_provider_message(message: &agent_sdk_core::ProviderMessage) -> Self {
        let role = role_name(&message.role).to_owned();
        let content = message.content.clone();
        Self { role, content }
    }
}
impl fmt::Debug for OpenAiInputMessage {
    /// Shows the role and the content length in characters, never the content itself.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let content_chars = self.content.chars().count();
        let mut view = formatter.debug_struct("OpenAiInputMessage");
        view.field("role", &self.role);
        view.field("content", &"<redacted>");
        view.field("content_chars", &content_chars);
        view.finish()
    }
}
/// Structured-output hint sent as the request's `text` field.
///
/// NOTE(review): the upstream OpenAI Responses API nests `type`/`name`/`schema` under
/// `text.format`. It also has no `schema_version`/`schema_fingerprint`/`include_schema_ref`
/// fields. Confirm that the transport reshapes this, or that the target server accepts
/// this flat form.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiTextFormatHint {
/// Serialized as `"type"`; `from_provider_hint` always sets `"json_schema"`.
#[serde(rename = "type")]
pub kind: String,
/// Schema identifier.
pub name: String,
/// Schema version rendered as `"major.minor.patch"`.
pub schema_version: String,
/// Schema fingerprint string.
pub schema_fingerprint: String,
/// Copied from the provider hint's `include_schema_ref` flag.
pub include_schema_ref: bool,
/// Optional JSON schema, taken from the hint's `redacted_schema`. Defaults to `None`
/// when absent on input and is omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub schema: Option<Value>,
}
impl OpenAiTextFormatHint {
    /// Builds a `json_schema` text-format hint from the provider's structured-output hint.
    ///
    /// The schema version is rendered as `major.minor.patch`.
    fn from_provider_hint(hint: &agent_sdk_core::ProviderStructuredOutputHint) -> Self {
        let version = &hint.schema_version;
        let schema_version = format!("{}.{}.{}", version.major, version.minor, version.patch);
        Self {
            kind: String::from("json_schema"),
            name: hint.schema_id.as_str().to_string(),
            schema_version,
            schema_fingerprint: hint.schema_fingerprint.as_str().to_string(),
            include_schema_ref: hint.include_schema_ref,
            schema: hint.redacted_schema.clone(),
        }
    }
}
impl fmt::Debug for OpenAiTextFormatHint {
    /// Prints the hint metadata. The schema body is reduced to a presence flag.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let schema_present = self.schema.is_some();
        let mut view = formatter.debug_struct("OpenAiTextFormatHint");
        view.field("kind", &self.kind);
        view.field("name", &self.name);
        view.field("schema_version", &self.schema_version);
        view.field("schema_fingerprint", &self.schema_fingerprint);
        view.field("include_schema_ref", &self.include_schema_ref);
        view.field("schema_present", &schema_present);
        view.finish()
    }
}
/// Decoded response from an OpenAI-compatible Responses call.
///
/// Only the fields the adapter reads are modelled; unknown fields are ignored on
/// deserialize. `Debug` is implemented by hand and redacts `output_text`.
#[derive(Clone, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesResponse {
/// Response id, if the server sent one.
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Response status. When there are no tool calls it maps to the stop reason:
/// `completed`, `cancelled`, `incomplete` and `failed` are recognised, a missing status
/// counts as `completed`, and anything else maps to `Unknown`.
#[serde(skip_serializing_if = "Option::is_none")]
pub status: Option<String>,
/// Aggregated output text; empty when absent. When non-empty it is used in preference
/// to the `output` message parts.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub output_text: String,
/// Output items (messages, function calls, ...); empty when absent.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub output: Vec<OpenAiWireOutputItem>,
/// Token usage reported by the server, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<OpenAiResponsesUsage>,
}
impl OpenAiResponsesResponse {
    /// Builds a `"completed"` response whose only payload is `output_text`.
    pub fn text(output_text: impl Into<String>) -> Self {
        let status = Some(String::from("completed"));
        Self {
            status,
            output_text: output_text.into(),
            ..Default::default()
        }
    }

    /// Builds a `"completed"` response holding a single `function_call` output item.
    pub fn function_call(
        call_id: impl Into<String>,
        name: impl Into<String>,
        arguments: impl Into<String>,
    ) -> Self {
        let status = Some(String::from("completed"));
        let item = OpenAiWireOutputItem::function_call(call_id, name, arguments);
        Self {
            status,
            output: vec![item],
            ..Default::default()
        }
    }

    /// Returns the response text.
    ///
    /// A non-empty top-level `output_text` is returned as is. Otherwise, the text of every
    /// `output_text` part inside `message` items is concatenated in order, with no
    /// separator.
    fn output_text(&self) -> String {
        if self.output_text.is_empty() {
            self.output
                .iter()
                .filter(|item| item.kind == "message")
                .flat_map(|item| &item.content)
                .filter(|part| part.kind == "output_text")
                .filter_map(|part| part.text.as_deref())
                .collect()
        } else {
            self.output_text.clone()
        }
    }

    /// Maps `status` to a stop reason for responses that contain no tool calls.
    ///
    /// A missing status is treated as `"completed"`; unrecognised values map to `Unknown`.
    fn stop_reason_without_tools(&self) -> ProviderStopReason {
        match self.status.as_deref() {
            None | Some("completed") => ProviderStopReason::EndTurn,
            Some("cancelled") => ProviderStopReason::Cancelled,
            Some("incomplete") => ProviderStopReason::MaxTokens,
            Some("failed") => ProviderStopReason::ProviderError,
            Some(_) => ProviderStopReason::Unknown,
        }
    }
}
impl fmt::Debug for OpenAiResponsesResponse {
    /// Prints response metadata. `output_text` is replaced by its character count, and the
    /// output items use their own (redacting) `Debug`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let output_text_chars = self.output_text.chars().count();
        let output_count = self.output.len();
        let mut view = formatter.debug_struct("OpenAiResponsesResponse");
        view.field("id", &self.id);
        view.field("status", &self.status);
        view.field("output_text", &"<redacted>");
        view.field("output_text_chars", &output_text_chars);
        view.field("output_count", &output_count);
        view.field("output", &self.output);
        view.field("usage", &self.usage);
        view.finish()
    }
}
/// One item of the response `output` array.
///
/// The adapter reads `"message"` items (for text) and `"function_call"` items (for tool
/// calls); other kinds are ignored. `Debug` redacts `arguments`.
#[derive(Clone, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiWireOutputItem {
/// Item type, serialized as `"type"`.
#[serde(rename = "type")]
pub kind: String,
/// Content parts of a `message` item; empty when absent.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub content: Vec<OpenAiContentPart>,
/// Call id of a `function_call` item; the adapter rejects such items without one.
#[serde(skip_serializing_if = "Option::is_none")]
pub call_id: Option<String>,
/// Tool name of a `function_call` item; the adapter rejects such items without one.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Raw, unparsed argument string of a `function_call` item.
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
impl OpenAiWireOutputItem {
    /// Builds a `function_call` output item carrying the given id, tool name, and raw
    /// argument string.
    pub fn function_call(
        call_id: impl Into<String>,
        name: impl Into<String>,
        arguments: impl Into<String>,
    ) -> Self {
        let (call_id, name, arguments) = (call_id.into(), name.into(), arguments.into());
        Self {
            kind: String::from("function_call"),
            call_id: Some(call_id),
            name: Some(name),
            arguments: Some(arguments),
            ..Default::default()
        }
    }

    /// Builds a `message` output item with a single `output_text` content part.
    pub fn message(text: impl Into<String>) -> Self {
        let part = OpenAiContentPart {
            kind: String::from("output_text"),
            text: Some(text.into()),
        };
        Self {
            kind: String::from("message"),
            content: vec![part],
            ..Default::default()
        }
    }
}
impl fmt::Debug for OpenAiWireOutputItem {
    /// Prints item metadata. The raw arguments are replaced by their character count.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let content_count = self.content.len();
        let arguments_chars = self.arguments.as_deref().map(|args| args.chars().count());
        let mut view = formatter.debug_struct("OpenAiWireOutputItem");
        view.field("kind", &self.kind);
        view.field("content_count", &content_count);
        view.field("content", &self.content);
        view.field("call_id", &self.call_id);
        view.field("name", &self.name);
        view.field("arguments", &"<redacted>");
        view.field("arguments_chars", &arguments_chars);
        view.finish()
    }
}
/// One content part of a `message` output item.
///
/// Only parts whose `kind` is `"output_text"` contribute to the response text.
/// `Debug` redacts `text`.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiContentPart {
/// Part type, serialized as `"type"`.
#[serde(rename = "type")]
pub kind: String,
/// Text payload, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
}
impl fmt::Debug for OpenAiContentPart {
    /// Shows the part kind and the text length in characters; the text itself is redacted.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text_chars = self.text.as_deref().map(|text| text.chars().count());
        let mut view = formatter.debug_struct("OpenAiContentPart");
        view.field("kind", &self.kind);
        view.field("text", &"<redacted>");
        view.field("text_chars", &text_chars);
        view.finish()
    }
}
/// Token usage as reported by the server; each count is optional.
///
/// Converted field-for-field into `ProviderUsage`; nothing is recomputed.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesUsage {
/// Input-token count, if reported.
pub input_tokens: Option<u32>,
/// Output-token count, if reported.
pub output_tokens: Option<u32>,
/// Total-token count, if reported (not derived from the other two here).
pub total_tokens: Option<u32>,
}
impl From<OpenAiResponsesUsage> for ProviderUsage {
fn from(value: OpenAiResponsesUsage) -> Self {
Self {
input_tokens: value.input_tokens,
output_tokens: value.output_tokens,
total_tokens: value.total_tokens,
}
}
}
/// Maps a provider message role to the role string sent on the wire.
///
/// `Context` has no dedicated wire role and is sent as `"user"`.
// NOTE(review): the upstream Responses API carries tool results as separate output items
// rather than `role: "tool"` messages. Confirm that the target server accepts `"tool"`.
fn role_name(role: &ProviderMessageRole) -> &'static str {
    match role {
        ProviderMessageRole::User | ProviderMessageRole::Context => "user",
        ProviderMessageRole::Assistant => "assistant",
        ProviderMessageRole::System => "system",
        ProviderMessageRole::Developer => "developer",
        ProviderMessageRole::Tool => "tool",
    }
}
/// Builds a `ProviderFailure` error classified as `RepairNeeded`.
///
/// Used when a provider response is malformed, e.g. a `function_call` item without a
/// `call_id` or `name`.
fn provider_failure(message: impl Into<String>) -> AgentError {
    let kind = AgentErrorKind::ProviderFailure;
    let retry = RetryClassification::RepairNeeded;
    AgentError::new(kind, retry, message)
}