Skip to main content

prompty_openai/
executor.rs

//! OpenAI executor — sends requests to the OpenAI HTTP API.
//!
//! Dispatches on `agent.model.apiType` to call the appropriate endpoint:
//! `chat` (also used for `agent`), `responses`, `embedding`, or `image`.
//! When `apiType` is not set, `chat` is assumed.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::LazyLock;
9
10use prompty::interfaces::{Executor, InvokerError};
11use prompty::model::Prompty;
12use prompty::types::Message;
13
14use crate::wire;
15
16/// Shared HTTP client — reuses connection pool across requests.
17static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);
18
/// OpenAI executor implementing the `Executor` trait.
///
/// The type carries no state: connection details and credentials are read
/// from the `Prompty` agent passed to each call, so values are freely
/// copyable and can be created with `OpenAIExecutor` or `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAIExecutor;
21
22#[async_trait]
23impl Executor for OpenAIExecutor {
24    async fn execute(&self, agent: &Prompty, messages: &[Message]) -> Result<Value, InvokerError> {
25        let api_type = agent
26            .model
27            .api_type
28            .as_ref()
29            .map(|t| t.as_str())
30            .unwrap_or("chat");
31
32        let (url, body) = match api_type {
33            "chat" | "agent" => {
34                let args = wire::build_chat_args(agent, messages);
35                let url = build_url(agent, "/v1/chat/completions")?;
36                (url, args)
37            }
38            "responses" => {
39                let args = wire::build_responses_args(agent, messages);
40                let url = build_url(agent, "/v1/responses")?;
41                (url, args)
42            }
43            "embedding" => {
44                let args = wire::build_embedding_args(agent, messages);
45                let url = build_url(agent, "/v1/embeddings")?;
46                (url, args)
47            }
48            "image" => {
49                let args = wire::build_image_args(agent, messages);
50                let url = build_url(agent, "/v1/images/generations")?;
51                (url, args)
52            }
53            other => {
54                return Err(InvokerError::Execute(
55                    format!("Unsupported apiType: {other}").into(),
56                ));
57            }
58        };
59
60        let api_key = get_api_key(agent)?;
61        let client = &*HTTP_CLIENT;
62        let response = client
63            .post(&url)
64            .header("Authorization", format!("Bearer {api_key}"))
65            .header("Content-Type", "application/json")
66            .json(&body)
67            .send()
68            .await
69            .map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
70
71        if !response.status().is_success() {
72            let status = response.status();
73            let body_text = response
74                .text()
75                .await
76                .unwrap_or_else(|_| "unable to read body".to_string());
77            return Err(InvokerError::Execute(
78                format!("OpenAI API error (HTTP {status}): {body_text}").into(),
79            ));
80        }
81
82        let result: Value = response
83            .json()
84            .await
85            .map_err(|e| InvokerError::Execute(format!("Failed to parse response: {e}").into()))?;
86
87        Ok(result)
88    }
89
90    fn format_tool_messages(
91        &self,
92        _raw_response: &serde_json::Value,
93        tool_calls: &[prompty::types::ToolCall],
94        tool_results: &[String],
95        _text_content: Option<&str>,
96    ) -> Vec<Message> {
97        wire::format_tool_messages(tool_calls, tool_results)
98    }
99
100    async fn execute_stream(
101        &self,
102        agent: &Prompty,
103        messages: &[Message],
104    ) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>, InvokerError> {
105        let api_type = agent
106            .model
107            .api_type
108            .as_ref()
109            .map(|t| t.as_str())
110            .unwrap_or("chat");
111
112        let (url, mut body) = match api_type {
113            "chat" | "agent" => {
114                let args = wire::build_chat_args(agent, messages);
115                let url = build_url(agent, "/v1/chat/completions")?;
116                (url, args)
117            }
118            "responses" => {
119                let args = wire::build_responses_args(agent, messages);
120                let url = build_url(agent, "/v1/responses")?;
121                (url, args)
122            }
123            other => {
124                return Err(InvokerError::Execute(
125                    format!("Streaming not supported for apiType: {other}").into(),
126                ));
127            }
128        };
129
130        // Force stream: true
131        if let Some(obj) = body.as_object_mut() {
132            obj.insert("stream".into(), Value::Bool(true));
133        }
134
135        let api_key = get_api_key(agent)?;
136        let client = &*HTTP_CLIENT;
137        let response = client
138            .post(&url)
139            .header("Authorization", format!("Bearer {api_key}"))
140            .header("Content-Type", "application/json")
141            .json(&body)
142            .send()
143            .await
144            .map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
145
146        if !response.status().is_success() {
147            let status = response.status();
148            let body_text = response
149                .text()
150                .await
151                .unwrap_or_else(|_| "unable to read body".to_string());
152            return Err(InvokerError::Execute(
153                format!("OpenAI API error (HTTP {status}): {body_text}").into(),
154            ));
155        }
156
157        let byte_stream = response.bytes_stream();
158        Ok(Box::pin(SseParser::new(byte_stream)))
159    }
160}
161
162impl OpenAIExecutor {
163    /// Build the request args without sending — useful for testing wire format.
164    pub fn build_args(agent: &Prompty, messages: &[Message]) -> Result<Value, InvokerError> {
165        let api_type = agent
166            .model
167            .api_type
168            .as_ref()
169            .map(|t| t.as_str())
170            .unwrap_or("chat");
171        Ok(match api_type {
172            "chat" | "agent" => wire::build_chat_args(agent, messages),
173            "embedding" => wire::build_embedding_args(agent, messages),
174            "image" => wire::build_image_args(agent, messages),
175            other => {
176                return Err(InvokerError::Execute(
177                    format!("Unsupported apiType: {other}").into(),
178                ));
179            }
180        })
181    }
182}
183
184// ---------------------------------------------------------------------------
185// Helpers
186// ---------------------------------------------------------------------------
187
188/// Resolve the effective connection — if `kind == "reference"`, look up the
189/// named connection from the registry. Otherwise return the connection as-is.
190fn resolve_connection(
191    agent: &Prompty,
192) -> Result<std::borrow::Cow<'_, serde_json::Value>, InvokerError> {
193    let conn = &agent.model.connection;
194    let kind = conn.get("kind").and_then(|k| k.as_str()).unwrap_or("");
195
196    if kind == "reference" {
197        let name = conn.get("name").and_then(|n| n.as_str()).ok_or_else(|| {
198            InvokerError::Execute(
199                "Reference connection missing 'name' field"
200                    .to_string()
201                    .into(),
202            )
203        })?;
204
205        // Look up the named connection from the registry
206        let resolved =
207            prompty::connections::with_connection::<serde_json::Value, _>(name, |c| c.clone())
208                .map_err(|e| InvokerError::Execute(e.into()))?;
209
210        Ok(std::borrow::Cow::Owned(resolved))
211    } else {
212        Ok(std::borrow::Cow::Borrowed(conn))
213    }
214}
215
216fn build_url(agent: &Prompty, path: &str) -> Result<String, InvokerError> {
217    let conn = resolve_connection(agent)?;
218
219    // 1. connection.endpoint from the agent
220    // 2. OPENAI_BASE_URL env var (matches OpenAI SDK behavior)
221    // 3. default https://api.openai.com
222    let endpoint = conn
223        .get("endpoint")
224        .and_then(|e| e.as_str())
225        .filter(|s| !s.is_empty())
226        .map(String::from)
227        .or_else(|| {
228            std::env::var("OPENAI_BASE_URL")
229                .ok()
230                .filter(|s| !s.is_empty())
231        })
232        .unwrap_or_else(|| "https://api.openai.com".to_string());
233
234    let base = endpoint.trim_end_matches('/');
235
236    // If base already includes /v1 (e.g. OPENAI_BASE_URL="https://proxy.example.com/openai/v1"),
237    // strip the leading /v1 from the path to avoid duplication.
238    let adjusted_path = if base.ends_with("/v1") || base.ends_with("/v1/") {
239        path.strip_prefix("/v1").unwrap_or(path)
240    } else {
241        path
242    };
243
244    Ok(format!("{base}{adjusted_path}"))
245}
246
247fn get_api_key(agent: &Prompty) -> Result<String, InvokerError> {
248    let conn = resolve_connection(agent)?;
249
250    // Try connection.apiKey first
251    if let Some(key) = conn
252        .get("apiKey")
253        .or(conn.get("api_key"))
254        .and_then(|k| k.as_str())
255    {
256        if !key.is_empty() {
257            return Ok(key.to_string());
258        }
259    }
260
261    // Fall back to OPENAI_API_KEY env var
262    if let Ok(key) = std::env::var("OPENAI_API_KEY") {
263        if !key.is_empty() {
264            return Ok(key);
265        }
266    }
267
268    Err(InvokerError::Execute(
269        "No API key found. Set OPENAI_API_KEY or configure model.connection.apiKey"
270            .to_string()
271            .into(),
272    ))
273}
274
275// ---------------------------------------------------------------------------
276// SSE stream parser — converts raw HTTP byte stream to JSON Value stream
277// ---------------------------------------------------------------------------
278
279use std::collections::VecDeque;
280use std::pin::Pin;
281use std::task::{Context, Poll};
282
283use bytes::Bytes;
284use futures::Stream;
285
286/// Parses Server-Sent Events (SSE) from a raw byte stream into JSON `Value` items.
287///
288/// Handles:
289/// - `data: [DONE]` → terminates the stream
290/// - `data: {...}` → yields parsed JSON
291/// - Multi-line buffers (splits on `\n\n`)
292struct SseParser {
293    inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
294    buffer: String,
295    pending: VecDeque<Value>,
296    done: bool,
297}
298
299impl SseParser {
300    fn new(inner: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static) -> Self {
301        Self {
302            inner: Box::pin(inner),
303            buffer: String::new(),
304            pending: VecDeque::new(),
305            done: false,
306        }
307    }
308
309    fn parse_buffer(&mut self) {
310        // SSE events are separated by double newlines
311        while let Some(pos) = self.buffer.find("\n\n") {
312            let event = self.buffer[..pos].to_string();
313            self.buffer = self.buffer[pos + 2..].to_string();
314
315            for line in event.lines() {
316                if let Some(data) = line
317                    .strip_prefix("data: ")
318                    .or_else(|| line.strip_prefix("data:"))
319                {
320                    let data = data.trim();
321                    if data == "[DONE]" {
322                        self.done = true;
323                        return;
324                    }
325                    match serde_json::from_str::<Value>(data) {
326                        Ok(parsed) => self.pending.push_back(parsed),
327                        Err(e) => {
328                            // Surface SSE JSON parse errors as error events
329                            // so consumers can detect malformed responses
330                            self.pending.push_back(serde_json::json!({
331                                "error": {
332                                    "type": "sse_parse_error",
333                                    "message": format!("Failed to parse SSE data: {e}"),
334                                    "raw": data,
335                                }
336                            }));
337                        }
338                    }
339                }
340            }
341        }
342    }
343}
344
345impl Stream for SseParser {
346    type Item = Value;
347
348    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349        loop {
350            // Drain pending items first
351            if let Some(item) = self.pending.pop_front() {
352                return Poll::Ready(Some(item));
353            }
354            if self.done {
355                return Poll::Ready(None);
356            }
357
358            // Pull more bytes from the inner stream
359            match self.inner.as_mut().poll_next(cx) {
360                Poll::Ready(Some(Ok(bytes))) => {
361                    match std::str::from_utf8(&bytes) {
362                        Ok(text) => self.buffer.push_str(text),
363                        Err(e) => {
364                            // Surface UTF-8 decode errors
365                            self.pending.push_back(serde_json::json!({
366                                "error": {
367                                    "type": "sse_decode_error",
368                                    "message": format!("Invalid UTF-8 in SSE stream: {e}"),
369                                }
370                            }));
371                        }
372                    }
373                    self.parse_buffer();
374                }
375                Poll::Ready(Some(Err(e))) => {
376                    // Surface transport errors instead of silently terminating
377                    self.pending.push_back(serde_json::json!({
378                        "error": {
379                            "type": "sse_transport_error",
380                            "message": format!("SSE stream error: {e}"),
381                        }
382                    }));
383                    self.done = true;
384                    // Drain pending (including error) before ending
385                    if let Some(item) = self.pending.pop_front() {
386                        return Poll::Ready(Some(item));
387                    }
388                    return Poll::Ready(None);
389                }
390                Poll::Ready(None) => {
391                    // Final buffer flush
392                    if !self.buffer.is_empty() {
393                        self.buffer.push_str("\n\n");
394                        self.parse_buffer();
395                    }
396                    if let Some(item) = self.pending.pop_front() {
397                        return Poll::Ready(Some(item));
398                    }
399                    return Poll::Ready(None);
400                }
401                Poll::Pending => return Poll::Pending,
402            }
403        }
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use prompty::model::Prompty;
    use prompty::model::context::LoadContext;
    use serde_json::json;
    use serial_test::serial;

    /// Load a minimal prompt agent whose `model` section is `model_json`.
    fn make_agent(model_json: Value) -> Prompty {
        let data = json!({
            "name": "test",
            "kind": "prompt",
            "model": model_json,
            "instructions": "test",
        });
        Prompty::load_from_value(&data, &LoadContext::default())
    }

    /// Agent for `gpt-4` with an inline `kind: key` connection.
    fn key_agent(endpoint: &str, api_key: &str) -> Prompty {
        make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": endpoint,
                "apiKey": api_key
            }
        }))
    }

    /// Agent for `gpt-4` whose connection is a reference to `name`.
    fn reference_agent(name: &str) -> Prompty {
        make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "reference",
                "name": name
            }
        }))
    }

    /// Feed `chunks` to an `SseParser` as separate network chunks and collect
    /// everything it yields.
    async fn collect_sse(chunks: Vec<&'static str>) -> Vec<Value> {
        use futures::StreamExt;

        let byte_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| Ok::<bytes::Bytes, reqwest::Error>(bytes::Bytes::from(chunk))),
        );
        SseParser::new(byte_stream).collect().await
    }

    // --- URL and API key resolution ---

    #[test]
    #[serial]
    fn test_build_url_default() {
        let agent = make_agent(json!({"id": "gpt-4"}));
        let url = build_url(&agent, "/v1/chat/completions").unwrap();
        assert_eq!(url, "https://api.openai.com/v1/chat/completions");
    }

    #[test]
    #[serial]
    fn test_build_url_custom_endpoint() {
        let agent = key_agent("https://custom.openai.com/", "sk-test");
        let url = build_url(&agent, "/v1/chat/completions").unwrap();
        assert_eq!(url, "https://custom.openai.com/v1/chat/completions");
    }

    #[test]
    #[serial]
    fn test_build_url_endpoint_with_v1_suffix() {
        // A base that already ends in /v1 must not produce /v1/v1/...,
        // with or without a trailing slash.
        for endpoint in [
            "https://proxy.example.com/openai/v1",
            "https://proxy.example.com/openai/v1/",
        ] {
            let agent = key_agent(endpoint, "sk-test");
            let url = build_url(&agent, "/v1/chat/completions").unwrap();
            assert_eq!(url, "https://proxy.example.com/openai/v1/chat/completions");
        }
    }

    #[test]
    #[serial]
    fn test_get_api_key_from_connection() {
        let agent = key_agent("https://api.openai.com", "sk-from-connection");
        let key = get_api_key(&agent).unwrap();
        assert_eq!(key, "sk-from-connection");
    }

    // --- Wire args ---

    #[test]
    #[serial]
    fn test_build_args_chat() {
        let agent = make_agent(json!({"id": "gpt-4", "apiType": "chat"}));
        let messages = vec![Message::with_text(prompty::Role::User, "Hello")];
        let args = OpenAIExecutor::build_args(&agent, &messages).unwrap();
        assert_eq!(args["model"], "gpt-4");
        assert!(args["messages"].is_array());
    }

    #[test]
    #[serial]
    fn test_build_args_embedding() {
        let agent = make_agent(json!({"id": "text-embedding-3-small", "apiType": "embedding"}));
        let messages = vec![Message::with_text(prompty::Role::User, "Hello world")];
        let args = OpenAIExecutor::build_args(&agent, &messages).unwrap();
        assert_eq!(args["model"], "text-embedding-3-small");
        assert!(args.get("input").is_some());
    }

    // --- SSE parsing ---

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_basic() {
        let items = collect_sse(vec![
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n\
             data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n\
             data: [DONE]\n\n",
        ])
        .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0]["choices"][0]["delta"]["content"], "Hello");
        assert_eq!(items[1]["choices"][0]["delta"]["content"], " world");
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_multi_chunk() {
        // Simulate data arriving in two separate network chunks, with the
        // event separator itself split across them.
        let items = collect_sse(vec![
            "data: {\"id\":1}\n",
            "\ndata: {\"id\":2}\n\ndata: [DONE]\n\n",
        ])
        .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0]["id"], 1);
        assert_eq!(items[1]["id"], 2);
    }

    #[tokio::test]
    #[serial]
    async fn test_sse_parser_surfaces_malformed_json() {
        // Malformed payloads become in-band error items; later events still parse.
        let items = collect_sse(vec![
            "data: not-json\n\ndata: {\"id\":7}\n\ndata: [DONE]\n\n",
        ])
        .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0]["error"]["type"], "sse_parse_error");
        assert_eq!(items[0]["error"]["raw"], "not-json");
        assert_eq!(items[1]["id"], 7);
    }

    // --- Reference connection resolution tests ---

    #[test]
    #[serial]
    fn test_resolve_connection_passthrough_key() {
        // Non-reference connections should pass through unchanged
        let agent = key_agent("https://api.openai.com", "sk-test");
        let conn = resolve_connection(&agent).unwrap();
        assert_eq!(conn.get("kind").unwrap().as_str().unwrap(), "key");
        assert_eq!(conn.get("apiKey").unwrap().as_str().unwrap(), "sk-test");
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_missing_name() {
        // A reference without a "name" field cannot be resolved.
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "reference"
            }
        }));
        let result = resolve_connection(&agent);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("name"));
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_not_registered() {
        prompty::connections::clear_connections();
        let agent = reference_agent("unregistered");
        let result = resolve_connection(&agent);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not registered"));
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_success() {
        prompty::connections::clear_connections();
        // Register a connection as a JSON Value
        prompty::connections::register_connection(
            "my-openai",
            json!({
                "kind": "key",
                "endpoint": "https://custom.openai.com",
                "apiKey": "sk-resolved"
            }),
        );

        let agent = reference_agent("my-openai");

        let conn = resolve_connection(&agent).unwrap();
        assert_eq!(
            conn.get("endpoint").unwrap().as_str().unwrap(),
            "https://custom.openai.com"
        );
        assert_eq!(conn.get("apiKey").unwrap().as_str().unwrap(), "sk-resolved");

        // Clean up
        prompty::connections::clear_connections();
    }

    #[test]
    #[serial]
    fn test_reference_connection_flows_to_build_url() {
        prompty::connections::clear_connections();
        prompty::connections::register_connection(
            "prod-openai",
            json!({
                "kind": "key",
                "endpoint": "https://prod.openai.proxy.com",
                "apiKey": "sk-prod"
            }),
        );

        let agent = reference_agent("prod-openai");

        let url = build_url(&agent, "/v1/chat/completions").unwrap();
        assert_eq!(url, "https://prod.openai.proxy.com/v1/chat/completions");

        let key = get_api_key(&agent).unwrap();
        assert_eq!(key, "sk-prod");

        prompty::connections::clear_connections();
    }
}