Skip to main content

langfuse_openai/
wrapper.rs

//! Traced wrappers around async-openai types that automatically create
//! Langfuse observation spans for the chat-completion and embedding API calls
//! made through them.
3
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use async_openai::Chat;
8use async_openai::Embeddings;
9use async_openai::config::Config;
10use async_openai::error::OpenAIError;
11use async_openai::types::chat::{
12    CreateChatCompletionRequest, CreateChatCompletionResponse, CreateChatCompletionStreamResponse,
13};
14use async_openai::types::embeddings::{CreateEmbeddingRequest, CreateEmbeddingResponse};
15use futures::Stream;
16
17use crate::parser::{self, ToolCallAccumulator};
18use langfuse::{LangfuseEmbedding, LangfuseGeneration};
19use langfuse_core::types::UsageDetails;
20
21/// A wrapper around async-openai's [`Chat`] that automatically creates
22/// Langfuse generation spans for every chat completion API call.
23pub struct TracedChat<'c, C: Config> {
24    inner: Chat<'c, C>,
25}
26
27impl<C: Config> std::fmt::Debug for TracedChat<'_, C> {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("TracedChat").finish()
30    }
31}
32
33impl<'c, C: Config> TracedChat<'c, C> {
34    /// Wrap an existing [`Chat`] instance with Langfuse tracing.
35    #[must_use]
36    pub fn new(chat: Chat<'c, C>) -> Self {
37        Self { inner: chat }
38    }
39
40    /// Create a chat completion with automatic Langfuse tracing.
41    ///
42    /// A generation span is created before the request and ended after the
43    /// response is received. Model, usage, input, and output are recorded
44    /// automatically.
45    ///
46    /// # Errors
47    ///
48    /// Returns the underlying [`OpenAIError`] if the API call fails.
49    pub async fn create(
50        &self,
51        request: CreateChatCompletionRequest,
52    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
53        let generation = LangfuseGeneration::start("chat-completion");
54        generation.set_input(&request);
55
56        match self.inner.create(request).await {
57            Ok(response) => {
58                generation.set_model(&parser::extract_model(&response));
59                if let Some(usage) = parser::extract_usage(&response) {
60                    generation.set_usage(&usage);
61                }
62                generation.set_output(&parser::extract_output(&response));
63                if let Some(tool_calls) = parser::extract_tool_calls(&response) {
64                    generation.set_tool_calls(&tool_calls);
65                }
66                generation.end();
67                Ok(response)
68            }
69            Err(err) => {
70                generation.set_level(langfuse_core::types::SpanLevel::Error);
71                generation.set_status_message(&err.to_string());
72                generation.end();
73                Err(err)
74            }
75        }
76    }
77
78    /// Create a streaming chat completion with automatic Langfuse tracing.
79    ///
80    /// A generation span is created before the request. The returned
81    /// [`TracedStream`] accumulates content from delta chunks and records
82    /// `completion_start_time` on the first chunk. The span is finalized
83    /// when the stream ends or is dropped.
84    ///
85    /// # Errors
86    ///
87    /// Returns the underlying [`OpenAIError`] if the API call fails.
88    pub async fn create_stream(
89        &self,
90        request: CreateChatCompletionRequest,
91    ) -> Result<TracedStream, OpenAIError> {
92        let generation = LangfuseGeneration::start("chat-completion");
93        generation.set_input(&request);
94
95        let stream = self.inner.create_stream(request).await?;
96        Ok(TracedStream::new(stream, generation))
97    }
98}
99
100/// Create a [`TracedChat`] wrapper from an async-openai
101/// [`Client`](async_openai::Client).
102pub fn observe_openai<C: Config>(client: &async_openai::Client<C>) -> TracedChat<'_, C> {
103    TracedChat::new(client.chat())
104}
105
// ---------------------------------------------------------------------------
// TracedStream
// ---------------------------------------------------------------------------
109
110/// A stream wrapper that accumulates content and records Langfuse
111/// generation attributes as chunks arrive.
112///
113/// On the first chunk, `completion_start_time` is set. On each chunk,
114/// delta content is accumulated. When the stream ends (or on drop), the
115/// generation span is finalized with model, usage, and output.
116pub struct TracedStream {
117    inner:
118        Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>,
119    generation: Option<LangfuseGeneration>,
120    accumulated_content: String,
121    model: Option<String>,
122    first_chunk: bool,
123    tool_call_acc: ToolCallAccumulator,
124}
125
126impl std::fmt::Debug for TracedStream {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("TracedStream")
129            .field("first_chunk", &self.first_chunk)
130            .field("model", &self.model)
131            .field("accumulated_content", &self.accumulated_content)
132            .finish_non_exhaustive()
133    }
134}
135
136impl TracedStream {
137    fn new(
138        inner: Pin<
139            Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>,
140        >,
141        generation: LangfuseGeneration,
142    ) -> Self {
143        Self {
144            inner,
145            generation: Some(generation),
146            accumulated_content: String::new(),
147            model: None,
148            first_chunk: true,
149            tool_call_acc: ToolCallAccumulator::new(),
150        }
151    }
152
153    /// End the generation span with whatever data has been accumulated so far.
154    fn finalize(&mut self) {
155        if let Some(generation) = self.generation.take() {
156            if let Some(model) = &self.model {
157                generation.set_model(model);
158            }
159            if !self.accumulated_content.is_empty() {
160                generation.set_output(&self.accumulated_content);
161            }
162            if self.tool_call_acc.has_calls() {
163                generation.set_tool_calls(&self.tool_call_acc.finalize());
164            }
165            generation.end();
166        }
167    }
168}
169
170impl Stream for TracedStream {
171    type Item = Result<CreateChatCompletionStreamResponse, OpenAIError>;
172
173    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
174        // `Pin<Box<dyn Stream + Send>>` is `Unpin`, so we can safely get `&mut Self`.
175        let this = self.get_mut();
176
177        match this.inner.as_mut().poll_next(cx) {
178            Poll::Ready(Some(Ok(chunk))) => {
179                // Record completion_start_time on the very first chunk.
180                if this.first_chunk {
181                    this.first_chunk = false;
182                    if let Some(span) = this.generation.as_ref() {
183                        span.set_completion_start_time(&chrono::Utc::now());
184                    }
185                }
186
187                // Capture model from the first chunk that has it.
188                if this.model.is_none() {
189                    this.model = Some(chunk.model.clone());
190                }
191
192                // Accumulate delta content.
193                if let Some(content) = parser::extract_stream_chunk_content(&chunk) {
194                    this.accumulated_content.push_str(&content);
195                }
196
197                // If this chunk carries usage (final chunk), record it.
198                if let Some(usage) = parser::extract_stream_usage(&chunk)
199                    && let Some(span) = this.generation.as_ref()
200                {
201                    span.set_usage(&usage);
202                }
203
204                // Accumulate tool call deltas.
205                this.tool_call_acc.accumulate(&chunk);
206
207                Poll::Ready(Some(Ok(chunk)))
208            }
209            Poll::Ready(Some(Err(err))) => {
210                // Record the error on the generation span and finalize.
211                if let Some(span) = this.generation.as_ref() {
212                    span.set_level(langfuse_core::types::SpanLevel::Error);
213                    span.set_status_message(&err.to_string());
214                }
215                this.finalize();
216                Poll::Ready(Some(Err(err)))
217            }
218            Poll::Ready(None) => {
219                // Stream ended — finalize the generation span.
220                this.finalize();
221                Poll::Ready(None)
222            }
223            Poll::Pending => Poll::Pending,
224        }
225    }
226}
227
228impl Drop for TracedStream {
229    fn drop(&mut self) {
230        self.finalize();
231    }
232}
233
// ---------------------------------------------------------------------------
// TracedEmbeddings
// ---------------------------------------------------------------------------
237
238/// A wrapper around async-openai's [`Embeddings`] that automatically creates
239/// Langfuse embedding spans for every embedding API call.
240pub struct TracedEmbeddings<'c, C: Config> {
241    inner: Embeddings<'c, C>,
242}
243
244impl<C: Config> std::fmt::Debug for TracedEmbeddings<'_, C> {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        f.debug_struct("TracedEmbeddings").finish()
247    }
248}
249
250impl<'c, C: Config> TracedEmbeddings<'c, C> {
251    /// Wrap an existing [`Embeddings`] instance with Langfuse tracing.
252    #[must_use]
253    pub fn new(embeddings: Embeddings<'c, C>) -> Self {
254        Self { inner: embeddings }
255    }
256
257    /// Create an embedding with automatic Langfuse tracing.
258    ///
259    /// An embedding span is created before the request and ended after the
260    /// response is received. Model, usage, and input are recorded
261    /// automatically. Output vectors are intentionally omitted (too large).
262    ///
263    /// # Errors
264    ///
265    /// Returns the underlying [`OpenAIError`] if the API call fails.
266    pub async fn create(
267        &self,
268        request: CreateEmbeddingRequest,
269    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
270        let embedding = LangfuseEmbedding::start("embedding");
271        embedding.set_input(&serde_json::json!(request.input));
272        embedding.set_model(&request.model);
273
274        match self.inner.create(request).await {
275            Ok(response) => {
276                // Record the model actually used (may differ from request).
277                embedding.set_model(&response.model);
278                embedding.set_usage(&UsageDetails {
279                    input: Some(u64::from(response.usage.prompt_tokens)),
280                    output: None,
281                    total: Some(u64::from(response.usage.total_tokens)),
282                });
283                // Intentionally skip output — embedding vectors are too large.
284                embedding.end();
285                Ok(response)
286            }
287            Err(err) => {
288                embedding.set_level(langfuse_core::types::SpanLevel::Error);
289                embedding.set_status_message(&err.to_string());
290                embedding.end();
291                Err(err)
292            }
293        }
294    }
295}
296
297/// Create a [`TracedEmbeddings`] wrapper from an async-openai
298/// [`Client`](async_openai::Client).
299pub fn observe_openai_embeddings<C: Config>(
300    client: &async_openai::Client<C>,
301) -> TracedEmbeddings<'_, C> {
302    TracedEmbeddings::new(client.embeddings())
303}