1use std::{fmt, sync::Arc};
7
8use agent_sdk_core::{
9 AgentError, AgentErrorKind, ProviderAdapter, ProviderCapabilities, ProviderMessageRole,
10 ProviderRequest, ProviderResponse, ProviderStopReason, ProviderToolCall, ProviderUsage,
11 RetryClassification, ToolCallId, domain::ContentRef as ContentRefId,
12 tool_records::CanonicalToolName,
13};
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16
/// Host-supplied settings for [`OpenAiCompatibleResponsesAdapter`].
///
/// Build one with [`OpenAiResponsesConfig::new`] and refine it with the chained
/// setters (`endpoint_ref`, `supports_streaming`, `max_input_tokens`).
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesConfig {
    /// Provider identifier passed to `ProviderCapabilities::text_only`.
    pub provider_ref: String,
    /// Model name copied into every outgoing [`OpenAiResponsesRequest`].
    pub model: String,
    /// Endpoint reference copied into every outgoing [`OpenAiResponsesRequest`];
    /// `new` defaults it to `"endpoint.host_configured.openai_compatible"`.
    pub endpoint_ref: String,
    /// Value advertised as `supports_streaming` in the adapter's capabilities.
    /// The adapter in this file only issues single `complete` calls.
    pub supports_streaming: bool,
    /// Input-token limit advertised in the adapter's capabilities; `None`
    /// advertises no limit.
    pub max_input_tokens: Option<u32>,
}
32
33impl OpenAiResponsesConfig {
34 pub fn new(provider_ref: impl Into<String>, model: impl Into<String>) -> Self {
36 Self {
37 provider_ref: provider_ref.into(),
38 model: model.into(),
39 endpoint_ref: "endpoint.host_configured.openai_compatible".to_string(),
40 supports_streaming: false,
41 max_input_tokens: None,
42 }
43 }
44
45 pub fn endpoint_ref(mut self, endpoint_ref: impl Into<String>) -> Self {
47 self.endpoint_ref = endpoint_ref.into();
48 self
49 }
50
51 pub fn supports_streaming(mut self, supports_streaming: bool) -> Self {
53 self.supports_streaming = supports_streaming;
54 self
55 }
56
57 pub fn max_input_tokens(mut self, max_input_tokens: u32) -> Self {
59 self.max_input_tokens = Some(max_input_tokens);
60 self
61 }
62}
63
/// Pluggable transport the adapter delegates each completion to.
///
/// It receives the fully built [`OpenAiResponsesRequest`] and returns the decoded
/// [`OpenAiResponsesResponse`]. How the request reaches the provider is up to the
/// implementation.
pub trait OpenAiResponsesTransport: Send + Sync {
    /// Sends `request` and returns the provider's response.
    ///
    /// # Errors
    /// Any `AgentError` returned here is propagated unchanged by the adapter's
    /// `ProviderAdapter::complete`.
    fn complete(
        &self,
        request: OpenAiResponsesRequest,
    ) -> Result<OpenAiResponsesResponse, AgentError>;
}
74
/// Optional hook that persists raw `function_call` argument strings.
///
/// The adapter in this file never copies raw arguments onto `ProviderToolCall`.
/// It hands them to this sink and attaches the returned content reference, if
/// any, via `with_args_ref`.
pub trait OpenAiToolArgumentSink: Send + Sync {
    /// Stores `raw_arguments` for tool call `call_id` targeting `canonical_tool_name`.
    ///
    /// Return `Ok(Some(content_ref))` to attach the reference to the tool call, or
    /// `Ok(None)` to leave the call without an argument reference.
    ///
    /// # Errors
    /// An error aborts mapping of the whole provider response.
    fn store_tool_arguments(
        &self,
        call_id: &str,
        canonical_tool_name: &CanonicalToolName,
        raw_arguments: &str,
    ) -> Result<Option<ContentRefId>, AgentError>;
}
87
/// `ProviderAdapter` for endpoints that speak the OpenAI Responses wire format.
///
/// Each `ProviderRequest` is translated into an [`OpenAiResponsesRequest`] and sent
/// through the injected [`OpenAiResponsesTransport`]. The resulting
/// [`OpenAiResponsesResponse`] is mapped back into a `ProviderResponse`.
///
/// Cloning copies the config and shares the transport and sink through their `Arc`s.
#[derive(Clone)]
pub struct OpenAiCompatibleResponsesAdapter {
    /// Model, endpoint, and advertised-capability settings.
    config: OpenAiResponsesConfig,
    /// Transport that performs each completion call.
    transport: Arc<dyn OpenAiResponsesTransport>,
    /// Optional sink for raw function-call arguments; `None` means arguments are not stored.
    argument_sink: Option<Arc<dyn OpenAiToolArgumentSink>>,
}
95
96impl OpenAiCompatibleResponsesAdapter {
97 pub fn new(
99 config: OpenAiResponsesConfig,
100 transport: Arc<dyn OpenAiResponsesTransport>,
101 ) -> Self {
102 Self {
103 config,
104 transport,
105 argument_sink: None,
106 }
107 }
108
109 pub fn with_argument_sink(mut self, sink: Arc<dyn OpenAiToolArgumentSink>) -> Self {
111 self.argument_sink = Some(sink);
112 self
113 }
114
115 pub fn config(&self) -> &OpenAiResponsesConfig {
117 &self.config
118 }
119
120 fn map_response(
121 &self,
122 response: OpenAiResponsesResponse,
123 ) -> Result<ProviderResponse, AgentError> {
124 let usage = response.usage.clone().map(ProviderUsage::from);
125 let tool_calls = self.tool_calls_from_response(&response)?;
126 if !tool_calls.is_empty() {
127 let mut mapped = ProviderResponse::tool_use(tool_calls);
128 mapped.usage = usage;
129 return Ok(mapped);
130 }
131
132 Ok(ProviderResponse {
133 schema_version: ProviderResponse::SCHEMA_VERSION,
134 output_text: response.output_text(),
135 stop_reason: response.stop_reason_without_tools(),
136 tool_calls: Vec::new(),
137 usage,
138 })
139 }
140
141 fn tool_calls_from_response(
142 &self,
143 response: &OpenAiResponsesResponse,
144 ) -> Result<Vec<ProviderToolCall>, AgentError> {
145 let mut calls = Vec::new();
146 for item in &response.output {
147 if item.kind != "function_call" {
148 continue;
149 }
150 let call_id = item.call_id.as_deref().ok_or_else(|| {
151 provider_failure("OpenAI-compatible function_call item missing call_id")
152 })?;
153 let name = item.name.as_deref().ok_or_else(|| {
154 provider_failure("OpenAI-compatible function_call item missing name")
155 })?;
156 let canonical_tool_name = CanonicalToolName::new(name);
157 let mut call = ProviderToolCall::new(
158 ToolCallId::new(call_id),
159 canonical_tool_name.clone(),
160 format!("provider requested tool {name} with arguments stored as content refs"),
161 );
162 if let (Some(sink), Some(raw_arguments)) =
163 (self.argument_sink.as_ref(), item.arguments.as_deref())
164 {
165 if let Some(args_ref) =
166 sink.store_tool_arguments(call_id, &canonical_tool_name, raw_arguments)?
167 {
168 call = call.with_args_ref(args_ref);
169 }
170 }
171 calls.push(call);
172 }
173 Ok(calls)
174 }
175}
176
177impl ProviderAdapter for OpenAiCompatibleResponsesAdapter {
178 fn capabilities(&self) -> ProviderCapabilities {
179 let mut capabilities = ProviderCapabilities::text_only(self.config.provider_ref.clone());
180 capabilities.supports_streaming = self.config.supports_streaming;
181 capabilities.max_input_tokens = self.config.max_input_tokens;
182 capabilities
183 }
184
185 fn complete(&self, request: &ProviderRequest) -> Result<ProviderResponse, AgentError> {
186 let wire_request = OpenAiResponsesRequest::from_provider_request(&self.config, request);
187 let response = self.transport.complete(wire_request)?;
188 self.map_response(response)
189 }
190}
191
/// Wire request built from a `ProviderRequest` and handed to the transport.
///
/// `Debug` is implemented manually so message contents are redacted.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesRequest {
    /// Model name taken from the adapter config.
    pub model: String,
    /// Conversation messages, in the same order as the provider request.
    pub input: Vec<OpenAiInputMessage>,
    /// Structured-output hint; omitted from serialized output when `None`.
    // NOTE(review): the official Responses API nests json_schema settings under
    // `text.format`; confirm the transport reshapes this field if it sends it as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<OpenAiTextFormatHint>,
    /// Endpoint reference taken from the adapter config.
    // NOTE(review): this is serialized together with the request body; confirm the
    // transport strips it before posting to endpoints that reject unknown parameters.
    pub endpoint_ref: String,
}
205
206impl OpenAiResponsesRequest {
207 pub fn from_provider_request(
209 config: &OpenAiResponsesConfig,
210 request: &ProviderRequest,
211 ) -> Self {
212 Self {
213 model: config.model.clone(),
214 input: request
215 .messages
216 .iter()
217 .map(OpenAiInputMessage::from_provider_message)
218 .collect(),
219 text: request
220 .structured_output_hint
221 .as_ref()
222 .map(OpenAiTextFormatHint::from_provider_hint),
223 endpoint_ref: config.endpoint_ref.clone(),
224 }
225 }
226}
227
228impl fmt::Debug for OpenAiResponsesRequest {
229 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
230 formatter
231 .debug_struct("OpenAiResponsesRequest")
232 .field("model", &self.model)
233 .field("input_count", &self.input.len())
234 .field("input", &"<redacted>")
235 .field("text", &self.text)
236 .field("endpoint_ref", &self.endpoint_ref)
237 .finish()
238 }
239}
240
/// One conversation message in an [`OpenAiResponsesRequest`].
///
/// `Debug` is implemented manually so `content` is redacted.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiInputMessage {
    /// Wire role string produced by `role_name`.
    pub role: String,
    /// Message text, copied verbatim from the provider message.
    pub content: String,
}
249
250impl OpenAiInputMessage {
251 fn from_provider_message(message: &agent_sdk_core::ProviderMessage) -> Self {
252 Self {
253 role: role_name(&message.role).to_string(),
254 content: message.content.clone(),
255 }
256 }
257}
258
259impl fmt::Debug for OpenAiInputMessage {
260 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
261 formatter
262 .debug_struct("OpenAiInputMessage")
263 .field("role", &self.role)
264 .field("content", &"<redacted>")
265 .field("content_chars", &self.content.chars().count())
266 .finish()
267 }
268}
269
/// Structured-output hint sent as the request's `text` field.
///
/// `Debug` is implemented manually so the schema body is never printed.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiTextFormatHint {
    /// Format type, serialized as `type`; `from_provider_hint` always sets `"json_schema"`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Schema identifier taken from the provider hint's `schema_id`.
    pub name: String,
    /// Schema version rendered as `major.minor.patch`.
    pub schema_version: String,
    /// Schema fingerprint taken from the provider hint.
    pub schema_fingerprint: String,
    /// Copied from the provider hint's `include_schema_ref` flag.
    pub include_schema_ref: bool,
    /// Schema document from the provider hint's `redacted_schema`. Omitted from
    /// serialized output when `None`, and defaults to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub schema: Option<Value>,
}
289
290impl OpenAiTextFormatHint {
291 fn from_provider_hint(hint: &agent_sdk_core::ProviderStructuredOutputHint) -> Self {
292 Self {
293 kind: "json_schema".to_string(),
294 name: hint.schema_id.as_str().to_string(),
295 schema_version: format!(
296 "{}.{}.{}",
297 hint.schema_version.major, hint.schema_version.minor, hint.schema_version.patch
298 ),
299 schema_fingerprint: hint.schema_fingerprint.as_str().to_string(),
300 include_schema_ref: hint.include_schema_ref,
301 schema: hint.redacted_schema.clone(),
302 }
303 }
304}
305
306impl fmt::Debug for OpenAiTextFormatHint {
307 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
308 formatter
309 .debug_struct("OpenAiTextFormatHint")
310 .field("kind", &self.kind)
311 .field("name", &self.name)
312 .field("schema_version", &self.schema_version)
313 .field("schema_fingerprint", &self.schema_fingerprint)
314 .field("include_schema_ref", &self.include_schema_ref)
315 .field("schema_present", &self.schema.is_some())
316 .finish()
317 }
318}
319
/// Wire response returned by an [`OpenAiResponsesTransport`].
///
/// `Debug` is implemented manually so the top-level output text is redacted.
#[derive(Clone, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesResponse {
    /// Provider-assigned response id, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Response status. The adapter maps `completed`, `cancelled`, `incomplete`,
    /// and `failed` to stop reasons, and treats a missing status as `completed`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<String>,
    /// Aggregated output text; defaults to empty when absent. When empty, the
    /// adapter assembles text from `message` items in `output` instead.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub output_text: String,
    /// Output items. Only `message` and `function_call` items are read by the
    /// adapter; other item types are ignored.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub output: Vec<OpenAiWireOutputItem>,
    /// Token usage, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<OpenAiResponsesUsage>,
}
339
340impl OpenAiResponsesResponse {
341 pub fn text(output_text: impl Into<String>) -> Self {
343 Self {
344 status: Some("completed".to_string()),
345 output_text: output_text.into(),
346 ..Self::default()
347 }
348 }
349
350 pub fn function_call(
352 call_id: impl Into<String>,
353 name: impl Into<String>,
354 arguments: impl Into<String>,
355 ) -> Self {
356 Self {
357 status: Some("completed".to_string()),
358 output: vec![OpenAiWireOutputItem::function_call(
359 call_id, name, arguments,
360 )],
361 ..Self::default()
362 }
363 }
364
365 fn output_text(&self) -> String {
366 if !self.output_text.is_empty() {
367 return self.output_text.clone();
368 }
369 self.output
370 .iter()
371 .filter(|item| item.kind == "message")
372 .flat_map(|item| item.content.iter())
373 .filter_map(|part| {
374 if part.kind == "output_text" {
375 part.text.clone()
376 } else {
377 None
378 }
379 })
380 .collect::<Vec<_>>()
381 .join("")
382 }
383
384 fn stop_reason_without_tools(&self) -> ProviderStopReason {
385 match self.status.as_deref().unwrap_or("completed") {
386 "completed" => ProviderStopReason::EndTurn,
387 "cancelled" => ProviderStopReason::Cancelled,
388 "incomplete" => ProviderStopReason::MaxTokens,
389 "failed" => ProviderStopReason::ProviderError,
390 _ => ProviderStopReason::Unknown,
391 }
392 }
393}
394
395impl fmt::Debug for OpenAiResponsesResponse {
396 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
397 formatter
398 .debug_struct("OpenAiResponsesResponse")
399 .field("id", &self.id)
400 .field("status", &self.status)
401 .field("output_text", &"<redacted>")
402 .field("output_text_chars", &self.output_text.chars().count())
403 .field("output_count", &self.output.len())
404 .field("output", &self.output)
405 .field("usage", &self.usage)
406 .finish()
407 }
408}
409
/// One entry of [`OpenAiResponsesResponse::output`].
///
/// The adapter reads `message` items (for text) and `function_call` items (for
/// tool calls). `Debug` is implemented manually so `arguments` is redacted.
#[derive(Clone, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiWireOutputItem {
    /// Item type, serialized as `type` (e.g. `"message"` or `"function_call"`).
    #[serde(rename = "type")]
    pub kind: String,
    /// Content parts of a `message` item; defaults to empty and is omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub content: Vec<OpenAiContentPart>,
    /// Tool-call id of a `function_call` item.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_id: Option<String>,
    /// Tool name of a `function_call` item.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Raw argument string of a `function_call` item. The adapter only passes it to
    /// the argument sink and never copies it onto the tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
429
430impl OpenAiWireOutputItem {
431 pub fn function_call(
433 call_id: impl Into<String>,
434 name: impl Into<String>,
435 arguments: impl Into<String>,
436 ) -> Self {
437 Self {
438 kind: "function_call".to_string(),
439 call_id: Some(call_id.into()),
440 name: Some(name.into()),
441 arguments: Some(arguments.into()),
442 ..Self::default()
443 }
444 }
445
446 pub fn message(text: impl Into<String>) -> Self {
448 Self {
449 kind: "message".to_string(),
450 content: vec![OpenAiContentPart {
451 kind: "output_text".to_string(),
452 text: Some(text.into()),
453 }],
454 ..Self::default()
455 }
456 }
457}
458
459impl fmt::Debug for OpenAiWireOutputItem {
460 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
461 formatter
462 .debug_struct("OpenAiWireOutputItem")
463 .field("kind", &self.kind)
464 .field("content_count", &self.content.len())
465 .field("content", &self.content)
466 .field("call_id", &self.call_id)
467 .field("name", &self.name)
468 .field("arguments", &"<redacted>")
469 .field(
470 "arguments_chars",
471 &self.arguments.as_ref().map(|value| value.chars().count()),
472 )
473 .finish()
474 }
475}
476
/// One content part of a `message` output item.
///
/// `Debug` is implemented manually so `text` is redacted.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiContentPart {
    /// Part type, serialized as `type`. Only `"output_text"` parts contribute to
    /// the assembled response text.
    #[serde(rename = "type")]
    pub kind: String,
    /// Part text, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
}
487
488impl fmt::Debug for OpenAiContentPart {
489 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
490 formatter
491 .debug_struct("OpenAiContentPart")
492 .field("kind", &self.kind)
493 .field("text", &"<redacted>")
494 .field(
495 "text_chars",
496 &self.text.as_ref().map(|value| value.chars().count()),
497 )
498 .finish()
499 }
500}
501
/// Token usage reported by the provider.
///
/// Each count is optional because it may be absent from the response. The value
/// is converted field-for-field into `ProviderUsage`.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct OpenAiResponsesUsage {
    /// Input (prompt) token count.
    pub input_tokens: Option<u32>,
    /// Output (generated) token count.
    pub output_tokens: Option<u32>,
    /// Total token count as reported; it is not recomputed when absent.
    pub total_tokens: Option<u32>,
}
512
513impl From<OpenAiResponsesUsage> for ProviderUsage {
514 fn from(value: OpenAiResponsesUsage) -> Self {
515 Self {
516 input_tokens: value.input_tokens,
517 output_tokens: value.output_tokens,
518 total_tokens: value.total_tokens,
519 }
520 }
521}
522
523fn role_name(role: &ProviderMessageRole) -> &'static str {
524 match role {
525 ProviderMessageRole::System => "system",
526 ProviderMessageRole::Developer => "developer",
527 ProviderMessageRole::User => "user",
528 ProviderMessageRole::Assistant => "assistant",
529 ProviderMessageRole::Tool => "tool",
530 ProviderMessageRole::Context => "user",
531 }
532}
533
534fn provider_failure(message: impl Into<String>) -> AgentError {
535 AgentError::new(
536 AgentErrorKind::ProviderFailure,
537 RetryClassification::RepairNeeded,
538 message,
539 )
540}