1use async_trait::async_trait;
6use serde_json::Value;
7
8use prompty::interfaces::{InvokerError, Processor};
9use prompty::model::Prompty;
10use prompty::types::ToolCall;
11
/// Marker type implementing [`Processor`] for OpenAI-style API payloads
/// (chat completions, Responses API, embeddings, and image generation).
pub struct OpenAIProcessor;
14
#[async_trait]
impl Processor for OpenAIProcessor {
    /// Process a complete (non-streaming) response by delegating to the free
    /// [`process_response`] function.
    async fn process(&self, agent: &Prompty, response: Value) -> Result<Value, InvokerError> {
        process_response(agent, &response)
    }

    /// Adapt a raw JSON chunk stream into [`prompty::types::StreamChunk`]s by
    /// delegating to the free [`process_stream`] function. Never fails.
    fn process_stream(
        &self,
        inner: std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = prompty::types::StreamChunk> + Send>>,
        InvokerError,
    > {
        Ok(process_stream(inner))
    }
}
31
32pub fn process_response(agent: &Prompty, response: &Value) -> Result<Value, InvokerError> {
34 if response.get("object").and_then(Value::as_str) == Some("response") {
36 return process_responses_api(agent, response);
37 }
38
39 if let Some(choices) = response.get("choices").and_then(Value::as_array) {
41 return process_chat_completion(agent, choices);
42 }
43
44 if response.get("object").and_then(Value::as_str) == Some("list") {
46 if let Some(data) = response.get("data").and_then(Value::as_array) {
47 return process_embedding(data);
48 }
49 }
50
51 if let Some(data) = response.get("data").and_then(Value::as_array) {
53 if data.iter().any(|d| {
54 d.get("url").is_some_and(|v| !v.is_null())
55 || d.get("b64_json").is_some_and(|v| !v.is_null())
56 }) {
57 return process_image(data);
58 }
59 }
60
61 Ok(response.clone())
63}
64
65fn process_chat_completion(agent: &Prompty, choices: &[Value]) -> Result<Value, InvokerError> {
70 let first = choices
71 .first()
72 .ok_or_else(|| InvokerError::Process("Empty choices array".to_string().into()))?;
73
74 let message = first
75 .get("message")
76 .ok_or_else(|| InvokerError::Process("Missing message in choice".to_string().into()))?;
77
78 if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) {
80 if !tool_calls.is_empty() {
81 let calls: Vec<Value> = tool_calls
82 .iter()
83 .map(|tc| {
84 let func = tc.get("function").unwrap_or(tc);
85 serde_json::json!({
86 "id": tc.get("id").and_then(Value::as_str).unwrap_or(""),
87 "name": func.get("name").and_then(Value::as_str).unwrap_or(""),
88 "arguments": func.get("arguments").and_then(Value::as_str).unwrap_or("{}"),
89 })
90 })
91 .collect();
92 return Ok(Value::Array(calls));
93 }
94 }
95
96 let content = message.get("content");
98
99 if content.is_none() || content == Some(&Value::Null) {
101 if let Some(refusal) = message.get("refusal").and_then(Value::as_str) {
102 return Ok(Value::String(refusal.to_string()));
103 }
104 }
105
106 let content_str = content.and_then(Value::as_str).unwrap_or("");
107
108 if let Some(outputs) = agent.as_outputs() {
111 if !outputs.is_empty() {
112 if let Ok(parsed) = serde_json::from_str::<Value>(content_str) {
113 return Ok(parsed);
114 }
115 }
117 }
118
119 Ok(Value::String(content_str.to_string()))
120}
121
122fn process_responses_api(agent: &Prompty, response: &Value) -> Result<Value, InvokerError> {
127 if let Some(output) = response.get("output").and_then(Value::as_array) {
129 let tool_calls: Vec<Value> = output
130 .iter()
131 .filter(|item| item.get("type").and_then(Value::as_str) == Some("function_call"))
132 .map(|item| {
133 serde_json::json!({
134 "id": item.get("call_id").and_then(Value::as_str).unwrap_or(""),
135 "name": item.get("name").and_then(Value::as_str).unwrap_or(""),
136 "arguments": item.get("arguments").and_then(Value::as_str).unwrap_or("{}"),
137 })
138 })
139 .collect();
140
141 if !tool_calls.is_empty() {
142 return Ok(Value::Array(tool_calls));
143 }
144 }
145
146 let output_text = response
148 .get("output_text")
149 .and_then(Value::as_str)
150 .unwrap_or("");
151
152 if let Some(outputs) = agent.as_outputs() {
154 if !outputs.is_empty() {
155 if let Ok(parsed) = serde_json::from_str::<Value>(output_text) {
156 return Ok(parsed);
157 }
158 }
159 }
160
161 Ok(Value::String(output_text.to_string()))
162}
163
164fn process_embedding(data: &[Value]) -> Result<Value, InvokerError> {
169 let vectors: Vec<Value> = data
170 .iter()
171 .filter_map(|d| d.get("embedding").cloned())
172 .collect();
173
174 if vectors.len() == 1 {
175 Ok(vectors.into_iter().next().unwrap())
176 } else {
177 Ok(Value::Array(vectors))
178 }
179}
180
181fn process_image(data: &[Value]) -> Result<Value, InvokerError> {
186 let urls: Vec<Value> = data
187 .iter()
188 .map(|d| {
189 let url = d.get("url").filter(|v| !v.is_null());
191 let b64 = d.get("b64_json").filter(|v| !v.is_null());
192 url.or(b64).cloned().unwrap_or(Value::Null)
193 })
194 .collect();
195
196 if urls.len() == 1 {
197 Ok(urls.into_iter().next().unwrap())
198 } else {
199 Ok(Value::Array(urls))
200 }
201}
202
203pub fn extract_tool_calls(response: &Value) -> Option<Vec<ToolCall>> {
209 let arr = response.as_array()?;
210 let calls: Vec<ToolCall> = arr
211 .iter()
212 .filter_map(|v| {
213 let id = v.get("id")?.as_str()?.to_string();
214 let name = v.get("name")?.as_str()?.to_string();
215 let arguments = v.get("arguments")?.as_str()?.to_string();
216 Some(ToolCall {
217 id,
218 name,
219 arguments,
220 })
221 })
222 .collect();
223 if calls.is_empty() { None } else { Some(calls) }
224}
225
226use prompty::types::StreamChunk;
231
232pub fn process_stream(
242 inner: impl futures::Stream<Item = Value> + Send + Unpin + 'static,
243) -> std::pin::Pin<Box<dyn futures::Stream<Item = StreamChunk> + Send>> {
244 Box::pin(OpenAIStreamProcessor::new(inner))
245}
246
/// Stream adapter that turns raw OpenAI SSE JSON chunks into [`StreamChunk`]s,
/// accumulating partial tool-call deltas until the inner stream ends.
struct OpenAIStreamProcessor {
    /// Source stream of raw chat-completion chunk payloads.
    inner: std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>,
    /// Per-index `(id, name, arguments)` accumulator for tool-call deltas;
    /// a `BTreeMap` keeps calls ordered by their `"index"` field.
    tool_call_acc: std::collections::BTreeMap<usize, (String, String, String)>,
    /// Current position in the stream lifecycle.
    phase: StreamPhase,
    /// Queue of chunks to emit before polling `inner` again.
    /// NOTE(review): nothing in this file pushes into this queue — confirm
    /// whether it is reserved for future use or dead.
    pending: std::collections::VecDeque<StreamChunk>,
}
257
/// Lifecycle of [`OpenAIStreamProcessor`].
enum StreamPhase {
    /// Forwarding text chunks while accumulating tool-call deltas.
    Streaming,
    /// Inner stream ended with accumulated tool calls; emitting them one by
    /// one. Fields: the calls and the index of the next one to yield.
    YieldingTools(Vec<ToolCall>, usize),
    /// Fully drained; every subsequent poll reports end-of-stream.
    Done,
}
264
impl OpenAIStreamProcessor {
    /// Create an adapter over `inner`, starting in the streaming phase with
    /// empty tool-call accumulator and pending queue.
    fn new(inner: impl futures::Stream<Item = Value> + Send + Unpin + 'static) -> Self {
        Self {
            inner: Box::pin(inner),
            tool_call_acc: std::collections::BTreeMap::new(),
            phase: StreamPhase::Streaming,
            pending: std::collections::VecDeque::new(),
        }
    }
}
275
276impl futures::Stream for OpenAIStreamProcessor {
277 type Item = StreamChunk;
278
279 fn poll_next(
280 self: std::pin::Pin<&mut Self>,
281 cx: &mut std::task::Context<'_>,
282 ) -> std::task::Poll<Option<Self::Item>> {
283 let this = self.get_mut();
284
285 if let Some(chunk) = this.pending.pop_front() {
287 return std::task::Poll::Ready(Some(chunk));
288 }
289
290 match &mut this.phase {
291 StreamPhase::Streaming => {
292 match this.inner.as_mut().poll_next(cx) {
293 std::task::Poll::Ready(Some(chunk)) => {
294 let delta = chunk
295 .get("choices")
296 .and_then(Value::as_array)
297 .and_then(|c| c.first())
298 .and_then(|c| c.get("delta"));
299
300 if let Some(delta) = delta {
301 if let Some(content) = delta.get("content").and_then(Value::as_str) {
303 if !content.is_empty() {
304 return std::task::Poll::Ready(Some(StreamChunk::Text(
305 content.to_string(),
306 )));
307 }
308 }
309
310 if let Some(tc_deltas) =
312 delta.get("tool_calls").and_then(Value::as_array)
313 {
314 for tc_delta in tc_deltas {
315 let idx =
316 tc_delta.get("index").and_then(Value::as_u64).unwrap_or(0)
317 as usize;
318 let entry =
319 this.tool_call_acc.entry(idx).or_insert_with(|| {
320 (String::new(), String::new(), String::new())
321 });
322 if let Some(id) = tc_delta.get("id").and_then(Value::as_str) {
323 entry.0 = id.to_string();
324 }
325 if let Some(name) =
326 tc_delta.pointer("/function/name").and_then(Value::as_str)
327 {
328 entry.1 = name.to_string();
329 }
330 if let Some(args) = tc_delta
331 .pointer("/function/arguments")
332 .and_then(Value::as_str)
333 {
334 entry.2.push_str(args);
335 }
336 }
337 }
338
339 if let Some(refusal) = delta.get("refusal").and_then(Value::as_str) {
341 if !refusal.is_empty() {
342 this.phase = StreamPhase::Done;
343 return std::task::Poll::Ready(Some(StreamChunk::Error(
344 format!("Model refused: {refusal}"),
345 )));
346 }
347 }
348 }
349
350 cx.waker().wake_by_ref();
352 std::task::Poll::Pending
353 }
354 std::task::Poll::Ready(None) => {
355 let tools: Vec<ToolCall> = this
357 .tool_call_acc
358 .values()
359 .map(|(id, name, args)| ToolCall {
360 id: id.clone(),
361 name: name.clone(),
362 arguments: args.clone(),
363 })
364 .collect();
365
366 if tools.is_empty() {
367 this.phase = StreamPhase::Done;
368 std::task::Poll::Ready(None)
369 } else {
370 let first = tools[0].clone();
371 this.phase = StreamPhase::YieldingTools(tools, 1);
372 std::task::Poll::Ready(Some(StreamChunk::Tool(first)))
373 }
374 }
375 std::task::Poll::Pending => std::task::Poll::Pending,
376 }
377 }
378 StreamPhase::YieldingTools(tools, idx) if *idx < tools.len() => {
379 let tc = tools[*idx].clone();
380 *idx += 1;
381 std::task::Poll::Ready(Some(StreamChunk::Tool(tc)))
382 }
383 StreamPhase::YieldingTools(..) => {
384 this.phase = StreamPhase::Done;
385 std::task::Poll::Ready(None)
386 }
387 StreamPhase::Done => std::task::Poll::Ready(None),
388 }
389 }
390}
391
#[cfg(test)]
mod tests {
    //! Unit tests covering every response shape (chat, tool calls, refusal,
    //! structured outputs, embeddings, images, unknown passthrough) plus the
    //! streaming adapter.

    use super::*;
    use prompty::model::context::LoadContext;
    use serde_json::json;

    // Build a minimal agent; `outputs_json` is attached as the `outputs`
    // declaration unless it is JSON null.
    fn make_agent(outputs_json: Value) -> Prompty {
        let mut data = json!({
            "name": "test",
            "kind": "prompt",
            "model": {"id": "gpt-4"},
            "instructions": "test",
        });
        if !outputs_json.is_null() {
            data["outputs"] = outputs_json;
        }
        Prompty::load_from_value(&data, &LoadContext::default())
    }

    #[test]
    fn test_process_chat_content() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!("Hello!"));
    }

    #[test]
    fn test_process_chat_tool_calls() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"SF\"}"
                        }
                    }]
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        let calls = result.as_array().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0]["name"], "get_weather");
        assert_eq!(calls[0]["id"], "call_1");
    }

    #[test]
    fn test_process_chat_refusal() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "refusal": "I can't do that"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!("I can't do that"));
    }

    #[test]
    fn test_process_structured_output() {
        // Declared outputs make the processor parse content as JSON.
        let agent = make_agent(json!([
            {"name": "answer", "kind": "string", "required": true}
        ]));
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "{\"answer\": \"42\"}"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result["answer"], "42");
    }

    #[test]
    fn test_process_embedding_single() {
        // A single embedding is returned bare, not wrapped in an array.
        let agent = make_agent(Value::Null);
        let response = json!({
            "object": "list",
            "data": [{
                "object": "embedding",
                "embedding": [0.1, 0.2, 0.3]
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!([0.1, 0.2, 0.3]));
    }

    #[test]
    fn test_process_embedding_multiple() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "object": "list",
            "data": [
                {"object": "embedding", "embedding": [0.1, 0.2]},
                {"object": "embedding", "embedding": [0.3, 0.4]}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!([[0.1, 0.2], [0.3, 0.4]]));
    }

    #[test]
    fn test_process_image_single() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "data": [{"url": "https://example.com/image.png"}]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!("https://example.com/image.png"));
    }

    #[test]
    fn test_process_image_multiple() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "data": [
                {"url": "https://a.png"},
                {"url": "https://b.png"}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!(["https://a.png", "https://b.png"]));
    }

    #[test]
    fn test_extract_tool_calls() {
        let val = json!([
            {"id": "c1", "name": "fn1", "arguments": "{}"},
            {"id": "c2", "name": "fn2", "arguments": "{\"x\":1}"}
        ]);
        let calls = extract_tool_calls(&val).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "fn1");
        assert_eq!(calls[1].name, "fn2");
    }

    #[test]
    fn test_extract_tool_calls_not_tool_response() {
        assert!(extract_tool_calls(&json!("Hello")).is_none());
        assert!(extract_tool_calls(&json!(42)).is_none());
    }

    #[test]
    fn test_empty_choices_error() {
        let agent = Prompty::default();
        let response = json!({
            "choices": []
        });
        let err = process_response(&agent, &response).unwrap_err();
        assert!(err.to_string().contains("Empty choices"));
    }

    #[test]
    fn test_missing_message_error() {
        let agent = Prompty::default();
        let response = json!({
            "choices": [{"finish_reason": "stop"}]
        });
        let err = process_response(&agent, &response).unwrap_err();
        assert!(err.to_string().contains("Missing message"));
    }

    #[test]
    fn test_tool_calls_with_missing_fields() {
        // Tool calls lacking a "function" object fall back to reading fields
        // from the call object itself, yielding empty strings when absent.
        let agent = Prompty::default();
        let response = json!({
            "choices": [{
                "message": {
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "function": {"name": "test", "arguments": "{}"}
                        },
                        {
                            "id": "call_2"
                        }
                    ]
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0]["name"], "test");
        assert_eq!(arr[1]["name"], "");
    }

    #[test]
    fn test_null_content_no_refusal() {
        // Null content without a refusal degrades to an empty string.
        let agent = Prompty::default();
        let response = json!({
            "choices": [{
                "message": {
                    "content": null
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, "");
    }

    #[test]
    fn test_unknown_response_shape_passthrough() {
        let agent = Prompty::default();
        let response = json!({
            "unexpected": "format",
            "custom": 42
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, response);
    }

    #[test]
    fn test_extract_tool_calls_empty_array() {
        assert!(extract_tool_calls(&json!([])).is_none());
    }

    #[test]
    fn test_extract_tool_calls_array_with_non_tool_objects() {
        let val = json!([{"foo": "bar"}, {"baz": 42}]);
        assert!(extract_tool_calls(&val).is_none());
    }

    #[test]
    fn test_structured_output_invalid_json_falls_back() {
        // Unparseable content with declared outputs falls back to raw text.
        let data = serde_json::json!({
            "kind": "prompt",
            "name": "structured",
            "model": "gpt-4",
            "outputs": [{"name": "result", "kind": "object"}],
            "instructions": "Return JSON"
        });
        let agent = Prompty::load_from_value(&data, &LoadContext::default());
        let response = json!({
            "choices": [{
                "message": {
                    "content": "this is not json"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, "this is not json");
    }

    #[test]
    fn test_embedding_multiple_vectors() {
        let agent = Prompty::default();
        let response = json!({
            "object": "list",
            "data": [
                {"embedding": [0.1, 0.2]},
                {"embedding": [0.3, 0.4]}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
    }

    #[test]
    fn test_image_multiple_urls() {
        let agent = Prompty::default();
        let response = json!({
            "data": [
                {"url": "https://a.com/1.png"},
                {"url": "https://a.com/2.png"}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
    }

    #[tokio::test]
    async fn test_stream_text_content() {
        use futures::StreamExt;
        let chunks = vec![
            json!({"choices": [{"delta": {"content": "Hello"}}]}),
            json!({"choices": [{"delta": {"content": " world"}}]}),
            // Empty delta: must be swallowed, not emitted.
            json!({"choices": [{"delta": {}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut texts = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::Text(t) => texts.push(t),
                StreamChunk::Tool(_) => panic!("unexpected tool call"),
                _ => {}
            }
        }
        assert_eq!(texts.join(""), "Hello world");
    }

    #[tokio::test]
    async fn test_stream_tool_calls() {
        // Arguments split across two deltas must be concatenated by index.
        use futures::StreamExt;
        let chunks = vec![
            json!({"choices": [{"delta": {"tool_calls": [
                {"index": 0, "id": "call_1", "function": {"name": "get_weather", "arguments": "{\"ci"}}
            ]}}]}),
            json!({"choices": [{"delta": {"tool_calls": [
                {"index": 0, "function": {"arguments": "ty\":\"SF\"}"}}
            ]}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut tools = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::Text(_) => {}
                StreamChunk::Tool(tc) => tools.push(tc),
                _ => {}
            }
        }
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].id, "call_1");
        assert_eq!(tools[0].name, "get_weather");
        assert_eq!(tools[0].arguments, "{\"city\":\"SF\"}");
    }

    #[tokio::test]
    async fn test_stream_refusal() {
        use futures::StreamExt;
        let chunks = vec![json!({"choices": [{"delta": {"refusal": "I cannot help with that"}}]})];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut errors = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let StreamChunk::Error(e) = chunk {
                errors.push(e);
            }
        }
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("refused"));
    }

    #[tokio::test]
    async fn test_stream_with_consume() {
        use prompty::types::consume_stream_chunks;
        let chunks = vec![
            json!({"choices": [{"delta": {"content": "Hello"}}]}),
            json!({"choices": [{"delta": {"content": " "}}]}),
            json!({"choices": [{"delta": {"content": "world"}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let stream = process_stream(inner);
        let (tool_calls, content) = consume_stream_chunks(stream, None).await;
        assert!(tool_calls.is_empty());
        assert_eq!(content, "Hello world");
    }

    #[tokio::test]
    async fn test_stream_mixed_content_then_tools() {
        // Text is emitted inline; tool calls are flushed after stream end.
        use futures::StreamExt;
        let chunks = vec![
            json!({"choices": [{"delta": {"content": "Let me check..."}}]}),
            json!({"choices": [{"delta": {"tool_calls": [
                {"index": 0, "id": "c1", "function": {"name": "search", "arguments": "{}"}}
            ]}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut texts = Vec::new();
        let mut tools = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::Text(t) => texts.push(t),
                StreamChunk::Tool(tc) => tools.push(tc),
                _ => {}
            }
        }
        assert_eq!(texts.join(""), "Let me check...");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
    }
}