1use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::LazyLock;
9
10use prompty::interfaces::{Executor, InvokerError};
11use prompty::model::Prompty;
12use prompty::types::Message;
13
14use crate::wire;
15
16static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);
18
/// Stateless [`Executor`] that talks to the OpenAI HTTP API (or an OpenAI-compatible endpoint).
///
/// All configuration (endpoint, API key, `apiType`) is read from the agent at call time, so the
/// executor itself carries no state and is trivially copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenAIExecutor;
21
22#[async_trait]
23impl Executor for OpenAIExecutor {
24 async fn execute(&self, agent: &Prompty, messages: &[Message]) -> Result<Value, InvokerError> {
25 let api_type = agent
26 .model
27 .api_type
28 .as_ref()
29 .map(|t| t.as_str())
30 .unwrap_or("chat");
31
32 let (url, body) = match api_type {
33 "chat" | "agent" => {
34 let args = wire::build_chat_args(agent, messages);
35 let url = build_url(agent, "/v1/chat/completions")?;
36 (url, args)
37 }
38 "responses" => {
39 let args = wire::build_responses_args(agent, messages);
40 let url = build_url(agent, "/v1/responses")?;
41 (url, args)
42 }
43 "embedding" => {
44 let args = wire::build_embedding_args(agent, messages);
45 let url = build_url(agent, "/v1/embeddings")?;
46 (url, args)
47 }
48 "image" => {
49 let args = wire::build_image_args(agent, messages);
50 let url = build_url(agent, "/v1/images/generations")?;
51 (url, args)
52 }
53 other => {
54 return Err(InvokerError::Execute(
55 format!("Unsupported apiType: {other}").into(),
56 ));
57 }
58 };
59
60 let api_key = get_api_key(agent)?;
61 let client = &*HTTP_CLIENT;
62 let response = client
63 .post(&url)
64 .header("Authorization", format!("Bearer {api_key}"))
65 .header("Content-Type", "application/json")
66 .json(&body)
67 .send()
68 .await
69 .map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
70
71 if !response.status().is_success() {
72 let status = response.status();
73 let body_text = response
74 .text()
75 .await
76 .unwrap_or_else(|_| "unable to read body".to_string());
77 return Err(InvokerError::Execute(
78 format!("OpenAI API error (HTTP {status}): {body_text}").into(),
79 ));
80 }
81
82 let result: Value = response
83 .json()
84 .await
85 .map_err(|e| InvokerError::Execute(format!("Failed to parse response: {e}").into()))?;
86
87 Ok(result)
88 }
89
90 fn format_tool_messages(
91 &self,
92 _raw_response: &serde_json::Value,
93 tool_calls: &[prompty::types::ToolCall],
94 tool_results: &[String],
95 _text_content: Option<&str>,
96 ) -> Vec<Message> {
97 wire::format_tool_messages(tool_calls, tool_results)
98 }
99
100 async fn execute_stream(
101 &self,
102 agent: &Prompty,
103 messages: &[Message],
104 ) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>, InvokerError> {
105 let api_type = agent
106 .model
107 .api_type
108 .as_ref()
109 .map(|t| t.as_str())
110 .unwrap_or("chat");
111
112 let (url, mut body) = match api_type {
113 "chat" | "agent" => {
114 let args = wire::build_chat_args(agent, messages);
115 let url = build_url(agent, "/v1/chat/completions")?;
116 (url, args)
117 }
118 "responses" => {
119 let args = wire::build_responses_args(agent, messages);
120 let url = build_url(agent, "/v1/responses")?;
121 (url, args)
122 }
123 other => {
124 return Err(InvokerError::Execute(
125 format!("Streaming not supported for apiType: {other}").into(),
126 ));
127 }
128 };
129
130 if let Some(obj) = body.as_object_mut() {
132 obj.insert("stream".into(), Value::Bool(true));
133 }
134
135 let api_key = get_api_key(agent)?;
136 let client = &*HTTP_CLIENT;
137 let response = client
138 .post(&url)
139 .header("Authorization", format!("Bearer {api_key}"))
140 .header("Content-Type", "application/json")
141 .json(&body)
142 .send()
143 .await
144 .map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
145
146 if !response.status().is_success() {
147 let status = response.status();
148 let body_text = response
149 .text()
150 .await
151 .unwrap_or_else(|_| "unable to read body".to_string());
152 return Err(InvokerError::Execute(
153 format!("OpenAI API error (HTTP {status}): {body_text}").into(),
154 ));
155 }
156
157 let byte_stream = response.bytes_stream();
158 Ok(Box::pin(SseParser::new(byte_stream)))
159 }
160}
161
162impl OpenAIExecutor {
163 pub fn build_args(agent: &Prompty, messages: &[Message]) -> Result<Value, InvokerError> {
165 let api_type = agent
166 .model
167 .api_type
168 .as_ref()
169 .map(|t| t.as_str())
170 .unwrap_or("chat");
171 Ok(match api_type {
172 "chat" | "agent" => wire::build_chat_args(agent, messages),
173 "embedding" => wire::build_embedding_args(agent, messages),
174 "image" => wire::build_image_args(agent, messages),
175 other => {
176 return Err(InvokerError::Execute(
177 format!("Unsupported apiType: {other}").into(),
178 ));
179 }
180 })
181 }
182}
183
184fn resolve_connection(
191 agent: &Prompty,
192) -> Result<std::borrow::Cow<'_, serde_json::Value>, InvokerError> {
193 let conn = &agent.model.connection;
194 let kind = conn.get("kind").and_then(|k| k.as_str()).unwrap_or("");
195
196 if kind == "reference" {
197 let name = conn.get("name").and_then(|n| n.as_str()).ok_or_else(|| {
198 InvokerError::Execute(
199 "Reference connection missing 'name' field"
200 .to_string()
201 .into(),
202 )
203 })?;
204
205 let resolved =
207 prompty::connections::with_connection::<serde_json::Value, _>(name, |c| c.clone())
208 .map_err(|e| InvokerError::Execute(e.into()))?;
209
210 Ok(std::borrow::Cow::Owned(resolved))
211 } else {
212 Ok(std::borrow::Cow::Borrowed(conn))
213 }
214}
215
216fn build_url(agent: &Prompty, path: &str) -> Result<String, InvokerError> {
217 let conn = resolve_connection(agent)?;
218
219 let endpoint = conn
223 .get("endpoint")
224 .and_then(|e| e.as_str())
225 .filter(|s| !s.is_empty())
226 .map(String::from)
227 .or_else(|| {
228 std::env::var("OPENAI_BASE_URL")
229 .ok()
230 .filter(|s| !s.is_empty())
231 })
232 .unwrap_or_else(|| "https://api.openai.com".to_string());
233
234 let base = endpoint.trim_end_matches('/');
235
236 let adjusted_path = if base.ends_with("/v1") || base.ends_with("/v1/") {
239 path.strip_prefix("/v1").unwrap_or(path)
240 } else {
241 path
242 };
243
244 Ok(format!("{base}{adjusted_path}"))
245}
246
247fn get_api_key(agent: &Prompty) -> Result<String, InvokerError> {
248 let conn = resolve_connection(agent)?;
249
250 if let Some(key) = conn
252 .get("apiKey")
253 .or(conn.get("api_key"))
254 .and_then(|k| k.as_str())
255 {
256 if !key.is_empty() {
257 return Ok(key.to_string());
258 }
259 }
260
261 if let Ok(key) = std::env::var("OPENAI_API_KEY") {
263 if !key.is_empty() {
264 return Ok(key);
265 }
266 }
267
268 Err(InvokerError::Execute(
269 "No API key found. Set OPENAI_API_KEY or configure model.connection.apiKey"
270 .to_string()
271 .into(),
272 ))
273}
274
275use std::collections::VecDeque;
280use std::pin::Pin;
281use std::task::{Context, Poll};
282
283use bytes::Bytes;
284use futures::Stream;
285
286struct SseParser {
293 inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
294 buffer: String,
295 pending: VecDeque<Value>,
296 done: bool,
297}
298
299impl SseParser {
300 fn new(inner: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static) -> Self {
301 Self {
302 inner: Box::pin(inner),
303 buffer: String::new(),
304 pending: VecDeque::new(),
305 done: false,
306 }
307 }
308
309 fn parse_buffer(&mut self) {
310 while let Some(pos) = self.buffer.find("\n\n") {
312 let event = self.buffer[..pos].to_string();
313 self.buffer = self.buffer[pos + 2..].to_string();
314
315 for line in event.lines() {
316 if let Some(data) = line
317 .strip_prefix("data: ")
318 .or_else(|| line.strip_prefix("data:"))
319 {
320 let data = data.trim();
321 if data == "[DONE]" {
322 self.done = true;
323 return;
324 }
325 match serde_json::from_str::<Value>(data) {
326 Ok(parsed) => self.pending.push_back(parsed),
327 Err(e) => {
328 self.pending.push_back(serde_json::json!({
331 "error": {
332 "type": "sse_parse_error",
333 "message": format!("Failed to parse SSE data: {e}"),
334 "raw": data,
335 }
336 }));
337 }
338 }
339 }
340 }
341 }
342 }
343}
344
345impl Stream for SseParser {
346 type Item = Value;
347
348 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349 loop {
350 if let Some(item) = self.pending.pop_front() {
352 return Poll::Ready(Some(item));
353 }
354 if self.done {
355 return Poll::Ready(None);
356 }
357
358 match self.inner.as_mut().poll_next(cx) {
360 Poll::Ready(Some(Ok(bytes))) => {
361 match std::str::from_utf8(&bytes) {
362 Ok(text) => self.buffer.push_str(text),
363 Err(e) => {
364 self.pending.push_back(serde_json::json!({
366 "error": {
367 "type": "sse_decode_error",
368 "message": format!("Invalid UTF-8 in SSE stream: {e}"),
369 }
370 }));
371 }
372 }
373 self.parse_buffer();
374 }
375 Poll::Ready(Some(Err(e))) => {
376 self.pending.push_back(serde_json::json!({
378 "error": {
379 "type": "sse_transport_error",
380 "message": format!("SSE stream error: {e}"),
381 }
382 }));
383 self.done = true;
384 if let Some(item) = self.pending.pop_front() {
386 return Poll::Ready(Some(item));
387 }
388 return Poll::Ready(None);
389 }
390 Poll::Ready(None) => {
391 if !self.buffer.is_empty() {
393 self.buffer.push_str("\n\n");
394 self.parse_buffer();
395 }
396 if let Some(item) = self.pending.pop_front() {
397 return Poll::Ready(Some(item));
398 }
399 return Poll::Ready(None);
400 }
401 Poll::Pending => return Poll::Pending,
402 }
403 }
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use prompty::model::Prompty;
    use prompty::model::context::LoadContext;
    use serde_json::json;
    use serial_test::serial;

    /// Loads a minimal `prompt` agent whose `model` section is `model_json`.
    fn make_agent(model_json: Value) -> Prompty {
        let data = json!({
            "name": "test",
            "kind": "prompt",
            "model": model_json,
            "instructions": "test",
        });
        Prompty::load_from_value(&data, &LoadContext::default())
    }

    /// Agent with an inline `kind: key` connection.
    fn key_agent(endpoint: &str, api_key: &str) -> Prompty {
        make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": endpoint,
                "apiKey": api_key,
            }
        }))
    }

    /// Agent whose connection is a `kind: reference` to the registered connection `name`.
    fn reference_agent(name: &str) -> Prompty {
        make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "reference",
                "name": name,
            }
        }))
    }

    /// Clears the global connection registry, then registers one `kind: key` connection.
    fn register_only(name: &'static str, endpoint: &str, api_key: &str) {
        prompty::connections::clear_connections();
        prompty::connections::register_connection(
            name,
            json!({
                "kind": "key",
                "endpoint": endpoint,
                "apiKey": api_key,
            }),
        );
    }

    /// Feeds `chunks` (one simulated network read each) through an [`SseParser`] and collects
    /// everything it yields.
    async fn collect_sse(chunks: &[&'static str]) -> Vec<Value> {
        use futures::StreamExt;

        let reads: Vec<Result<bytes::Bytes, reqwest::Error>> = chunks
            .iter()
            .map(|chunk| Ok(bytes::Bytes::from(*chunk)))
            .collect();
        SseParser::new(futures::stream::iter(reads)).collect().await
    }

    #[test]
    #[serial]
    fn test_build_url_default() {
        // The built-in default only applies when OPENAI_BASE_URL is unset or empty; skip rather
        // than fail spuriously when the developer's environment overrides it.
        if std::env::var("OPENAI_BASE_URL").is_ok_and(|v| !v.is_empty()) {
            return;
        }
        let agent = make_agent(json!({"id": "gpt-4"}));
        assert_eq!(
            build_url(&agent, "/v1/chat/completions").unwrap(),
            "https://api.openai.com/v1/chat/completions"
        );
    }

    #[test]
    #[serial]
    fn test_build_url_custom_endpoint() {
        let agent = key_agent("https://custom.openai.com/", "sk-test");
        assert_eq!(
            build_url(&agent, "/v1/chat/completions").unwrap(),
            "https://custom.openai.com/v1/chat/completions"
        );
    }

    #[test]
    #[serial]
    fn test_build_url_endpoint_ending_in_v1() {
        let agent = key_agent("https://proxy.example.com/v1/", "sk-test");
        assert_eq!(
            build_url(&agent, "/v1/chat/completions").unwrap(),
            "https://proxy.example.com/v1/chat/completions"
        );
    }

    #[test]
    #[serial]
    fn test_get_api_key_from_connection() {
        let agent = key_agent("https://api.openai.com", "sk-from-connection");
        assert_eq!(get_api_key(&agent).unwrap(), "sk-from-connection");
    }

    #[test]
    #[serial]
    fn test_build_args_chat() {
        let agent = make_agent(json!({"id": "gpt-4", "apiType": "chat"}));
        let messages = vec![Message::with_text(prompty::Role::User, "Hello")];
        let args = OpenAIExecutor::build_args(&agent, &messages).unwrap();
        assert_eq!(args["model"], "gpt-4");
        assert!(args["messages"].is_array());
    }

    #[test]
    #[serial]
    fn test_build_args_embedding() {
        let agent = make_agent(json!({"id": "text-embedding-3-small", "apiType": "embedding"}));
        let messages = vec![Message::with_text(prompty::Role::User, "Hello world")];
        let args = OpenAIExecutor::build_args(&agent, &messages).unwrap();
        assert_eq!(args["model"], "text-embedding-3-small");
        assert!(args.get("input").is_some());
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_basic() {
        let items = collect_sse(&["data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n\
             data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n\
             data: [DONE]\n\n"])
        .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0]["choices"][0]["delta"]["content"], "Hello");
        assert_eq!(items[1]["choices"][0]["delta"]["content"], " world");
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_multi_chunk() {
        // The blank line that ends the first event is split across two reads.
        let items = collect_sse(&[
            "data: {\"id\":1}\n",
            "\ndata: {\"id\":2}\n\ndata: [DONE]\n\n",
        ])
        .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0]["id"], 1);
        assert_eq!(items[1]["id"], 2);
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_crlf_line_endings() {
        let items = collect_sse(&["data: {\"id\":1}\r\n\r\ndata: {\"id\":2}\r\n\r\n"]).await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0]["id"], 1);
        assert_eq!(items[1]["id"], 2);
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_flushes_unterminated_final_event() {
        let items = collect_sse(&["data: {\"id\":7}"]).await;

        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["id"], 7);
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_reports_invalid_json() {
        let items = collect_sse(&["data: not-json\n\ndata: [DONE]\n\n"]).await;

        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["error"]["type"], "sse_parse_error");
        assert_eq!(items[0]["error"]["raw"], "not-json");
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_ignores_non_data_fields() {
        let items = collect_sse(&[": keep-alive\n\nevent: message\ndata: {\"id\":1}\n\n"]).await;

        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["id"], 1);
    }

    #[test]
    #[serial]
    fn test_resolve_connection_passthrough_key() {
        let agent = key_agent("https://api.openai.com", "sk-test");
        let conn = resolve_connection(&agent).unwrap();
        assert_eq!(conn.get("kind").unwrap().as_str().unwrap(), "key");
        assert_eq!(conn.get("apiKey").unwrap().as_str().unwrap(), "sk-test");
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_missing_name() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "reference"
            }
        }));
        let err = resolve_connection(&agent).unwrap_err();
        assert!(err.to_string().contains("name"));
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_not_registered() {
        prompty::connections::clear_connections();
        let agent = reference_agent("unregistered");
        let err = resolve_connection(&agent).unwrap_err();
        assert!(err.to_string().contains("not registered"));
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_success() {
        register_only("my-openai", "https://custom.openai.com", "sk-resolved");

        let agent = reference_agent("my-openai");
        let conn = resolve_connection(&agent).unwrap();
        assert_eq!(
            conn.get("endpoint").unwrap().as_str().unwrap(),
            "https://custom.openai.com"
        );
        assert_eq!(conn.get("apiKey").unwrap().as_str().unwrap(), "sk-resolved");

        prompty::connections::clear_connections();
    }

    #[test]
    #[serial]
    fn test_reference_connection_flows_to_build_url() {
        register_only("prod-openai", "https://prod.openai.proxy.com", "sk-prod");

        let agent = reference_agent("prod-openai");
        assert_eq!(
            build_url(&agent, "/v1/chat/completions").unwrap(),
            "https://prod.openai.proxy.com/v1/chat/completions"
        );
        assert_eq!(get_api_key(&agent).unwrap(), "sk-prod");

        prompty::connections::clear_connections();
    }
}