// walrus_model/openai/provider.rs
1//! Model trait implementation for the OpenAI-compatible provider.
2
3use super::OpenAI;
4use anyhow::Result;
5use async_stream::try_stream;
6use compact_str::CompactString;
7use futures_core::Stream;
8use futures_util::StreamExt;
9use reqwest::Method;
10use wcore::model::{Model, Response, StreamChunk};
11
12impl Model for OpenAI {
13    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
14        let body = super::request::Request::from(request.clone());
15        tracing::trace!("request: {}", serde_json::to_string(&body)?);
16        let response = self
17            .client
18            .request(Method::POST, &self.endpoint)
19            .headers(self.headers.clone())
20            .json(&body)
21            .send()
22            .await?;
23
24        let status = response.status();
25        let text = response.text().await?;
26        if !status.is_success() {
27            anyhow::bail!("OpenAI API error ({status}): {text}");
28        }
29
30        serde_json::from_str(&text).map_err(Into::into)
31    }
32
33    fn stream(
34        &self,
35        request: wcore::model::Request,
36    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
37        let usage = request.usage;
38        let body = super::request::Request::from(request).stream(usage);
39        if let Ok(body) = serde_json::to_string(&body) {
40            tracing::trace!("request: {}", body);
41        }
42        let request = self
43            .client
44            .request(Method::POST, &self.endpoint)
45            .headers(self.headers.clone())
46            .json(&body);
47
48        try_stream! {
49            let response = request.send().await?;
50            let mut stream = response.bytes_stream();
51            while let Some(Ok(bytes)) = stream.next().await {
52                let text = String::from_utf8_lossy(&bytes).into_owned();
53                tracing::trace!("chunk: {}", text);
54                for data in text.split("data: ").skip(1).filter(|s| !s.starts_with("[DONE]")) {
55                    let trimmed = data.trim();
56                    if trimmed.is_empty() {
57                        continue;
58                    }
59                    match serde_json::from_str::<StreamChunk>(trimmed) {
60                        Ok(chunk) => yield chunk,
61                        Err(e) => tracing::warn!("failed to parse chunk: {e}, data: {trimmed}"),
62                    }
63                }
64            }
65        }
66    }
67
68    fn active_model(&self) -> CompactString {
69        CompactString::from("gpt-4o")
70    }
71}