// langfuse_openai/wrapper.rs

use std::pin::Pin;
use std::task::{Context, Poll};

use async_openai::Chat;
use async_openai::Embeddings;
use async_openai::config::Config;
use async_openai::error::OpenAIError;
use async_openai::types::chat::{
    CreateChatCompletionRequest, CreateChatCompletionResponse, CreateChatCompletionStreamResponse,
};
use async_openai::types::embeddings::{CreateEmbeddingRequest, CreateEmbeddingResponse};
use futures::Stream;

use crate::parser::{self, ToolCallAccumulator};
use langfuse::{LangfuseEmbedding, LangfuseGeneration};
use langfuse_core::types::UsageDetails;
/// Wrapper around an `async_openai` [`Chat`] handle that records each
/// chat-completion call as a Langfuse generation.
pub struct TracedChat<'c, C: Config> {
    // The wrapped chat API handle; all requests are forwarded to it.
    inner: Chat<'c, C>,
}
26
27impl<C: Config> std::fmt::Debug for TracedChat<'_, C> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("TracedChat").finish()
30 }
31}
32
33impl<'c, C: Config> TracedChat<'c, C> {
34 #[must_use]
36 pub fn new(chat: Chat<'c, C>) -> Self {
37 Self { inner: chat }
38 }
39
40 pub async fn create(
50 &self,
51 request: CreateChatCompletionRequest,
52 ) -> Result<CreateChatCompletionResponse, OpenAIError> {
53 let generation = LangfuseGeneration::start("chat-completion");
54 generation.set_input(&request);
55
56 match self.inner.create(request).await {
57 Ok(response) => {
58 generation.set_model(&parser::extract_model(&response));
59 if let Some(usage) = parser::extract_usage(&response) {
60 generation.set_usage(&usage);
61 }
62 generation.set_output(&parser::extract_output(&response));
63 if let Some(tool_calls) = parser::extract_tool_calls(&response) {
64 generation.set_tool_calls(&tool_calls);
65 }
66 generation.end();
67 Ok(response)
68 }
69 Err(err) => {
70 generation.set_level(langfuse_core::types::SpanLevel::Error);
71 generation.set_status_message(&err.to_string());
72 generation.end();
73 Err(err)
74 }
75 }
76 }
77
78 pub async fn create_stream(
89 &self,
90 request: CreateChatCompletionRequest,
91 ) -> Result<TracedStream, OpenAIError> {
92 let generation = LangfuseGeneration::start("chat-completion");
93 generation.set_input(&request);
94
95 let stream = self.inner.create_stream(request).await?;
96 Ok(TracedStream::new(stream, generation))
97 }
98}
99
100pub fn observe_openai<C: Config>(client: &async_openai::Client<C>) -> TracedChat<'_, C> {
103 TracedChat::new(client.chat())
104}
105
/// A chat-completion stream that forwards every chunk to the consumer while
/// accumulating content, model, usage, and tool calls into a Langfuse
/// generation, which is ended when the stream finishes, errors, or is dropped.
pub struct TracedStream {
    // Boxed upstream stream of completion chunks from async_openai.
    inner:
        Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>,
    // Taken (set to None) when finalized, so the generation is ended at most once.
    generation: Option<LangfuseGeneration>,
    // Concatenation of every content delta seen so far.
    accumulated_content: String,
    // Model name captured from the first chunk; reported at finalization.
    model: Option<String>,
    // True until the first chunk arrives; used to record the completion start time.
    first_chunk: bool,
    // Accumulates partial tool-call deltas across chunks.
    tool_call_acc: ToolCallAccumulator,
}
125
126impl std::fmt::Debug for TracedStream {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("TracedStream")
129 .field("first_chunk", &self.first_chunk)
130 .field("model", &self.model)
131 .field("accumulated_content", &self.accumulated_content)
132 .finish_non_exhaustive()
133 }
134}
135
136impl TracedStream {
137 fn new(
138 inner: Pin<
139 Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>,
140 >,
141 generation: LangfuseGeneration,
142 ) -> Self {
143 Self {
144 inner,
145 generation: Some(generation),
146 accumulated_content: String::new(),
147 model: None,
148 first_chunk: true,
149 tool_call_acc: ToolCallAccumulator::new(),
150 }
151 }
152
153 fn finalize(&mut self) {
155 if let Some(generation) = self.generation.take() {
156 if let Some(model) = &self.model {
157 generation.set_model(model);
158 }
159 if !self.accumulated_content.is_empty() {
160 generation.set_output(&self.accumulated_content);
161 }
162 if self.tool_call_acc.has_calls() {
163 generation.set_tool_calls(&self.tool_call_acc.finalize());
164 }
165 generation.end();
166 }
167 }
168}
169
170impl Stream for TracedStream {
171 type Item = Result<CreateChatCompletionStreamResponse, OpenAIError>;
172
173 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
174 let this = self.get_mut();
176
177 match this.inner.as_mut().poll_next(cx) {
178 Poll::Ready(Some(Ok(chunk))) => {
179 if this.first_chunk {
181 this.first_chunk = false;
182 if let Some(span) = this.generation.as_ref() {
183 span.set_completion_start_time(&chrono::Utc::now());
184 }
185 }
186
187 if this.model.is_none() {
189 this.model = Some(chunk.model.clone());
190 }
191
192 if let Some(content) = parser::extract_stream_chunk_content(&chunk) {
194 this.accumulated_content.push_str(&content);
195 }
196
197 if let Some(usage) = parser::extract_stream_usage(&chunk)
199 && let Some(span) = this.generation.as_ref()
200 {
201 span.set_usage(&usage);
202 }
203
204 this.tool_call_acc.accumulate(&chunk);
206
207 Poll::Ready(Some(Ok(chunk)))
208 }
209 Poll::Ready(Some(Err(err))) => {
210 if let Some(span) = this.generation.as_ref() {
212 span.set_level(langfuse_core::types::SpanLevel::Error);
213 span.set_status_message(&err.to_string());
214 }
215 this.finalize();
216 Poll::Ready(Some(Err(err)))
217 }
218 Poll::Ready(None) => {
219 this.finalize();
221 Poll::Ready(None)
222 }
223 Poll::Pending => Poll::Pending,
224 }
225 }
226}
227
// Ensures the generation is ended even if the consumer drops the stream
// before exhausting it; `finalize` is a no-op if it already ran.
impl Drop for TracedStream {
    fn drop(&mut self) {
        self.finalize();
    }
}
233
/// Wrapper around an `async_openai` [`Embeddings`] handle that records each
/// embedding call as a Langfuse embedding span.
pub struct TracedEmbeddings<'c, C: Config> {
    // The wrapped embeddings API handle; all requests are forwarded to it.
    inner: Embeddings<'c, C>,
}
243
244impl<C: Config> std::fmt::Debug for TracedEmbeddings<'_, C> {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 f.debug_struct("TracedEmbeddings").finish()
247 }
248}
249
250impl<'c, C: Config> TracedEmbeddings<'c, C> {
251 #[must_use]
253 pub fn new(embeddings: Embeddings<'c, C>) -> Self {
254 Self { inner: embeddings }
255 }
256
257 pub async fn create(
267 &self,
268 request: CreateEmbeddingRequest,
269 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
270 let embedding = LangfuseEmbedding::start("embedding");
271 embedding.set_input(&serde_json::json!(request.input));
272 embedding.set_model(&request.model);
273
274 match self.inner.create(request).await {
275 Ok(response) => {
276 embedding.set_model(&response.model);
278 embedding.set_usage(&UsageDetails {
279 input: Some(u64::from(response.usage.prompt_tokens)),
280 output: None,
281 total: Some(u64::from(response.usage.total_tokens)),
282 });
283 embedding.end();
285 Ok(response)
286 }
287 Err(err) => {
288 embedding.set_level(langfuse_core::types::SpanLevel::Error);
289 embedding.set_status_message(&err.to_string());
290 embedding.end();
291 Err(err)
292 }
293 }
294 }
295}
296
297pub fn observe_openai_embeddings<C: Config>(
300 client: &async_openai::Client<C>,
301) -> TracedEmbeddings<'_, C> {
302 TracedEmbeddings::new(client.embeddings())
303}