//! Langfuse tracing wrappers for `async_openai` chat completions and embeddings.
//!
//! Each request is recorded as a Langfuse generation (or embedding) span:
//! inputs are captured before the call, and model, usage, output, and tool
//! calls are captured from the response. Errors are recorded with an error
//! level and status message before the span is ended.

use std::pin::Pin;
use std::task::{Context, Poll};

use async_openai::Chat;
use async_openai::Embeddings;
use async_openai::config::Config;
use async_openai::error::OpenAIError;
use async_openai::types::chat::{
    CreateChatCompletionRequest, CreateChatCompletionResponse, CreateChatCompletionStreamResponse,
};
use async_openai::types::embeddings::{CreateEmbeddingRequest, CreateEmbeddingResponse};
use futures::Stream;
use langfuse::{LangfuseEmbedding, LangfuseGeneration};
use langfuse_core::types::UsageDetails;

use crate::parser::{self, ToolCallAccumulator};
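/// A wrapper around [`Chat`] that records every chat-completion call as a
/// Langfuse generation, capturing input, model, usage, output, and tool calls.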
pub struct TracedChat<'c, C: Config> {
inner: Chat<'c, C>,
}
impl<C: Config> std::fmt::Debug for TracedChat<'_, C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TracedChat").finish()
}
}
impl<'c, C: Config> TracedChat<'c, C> {
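    /// Wraps an existing [`Chat`] handle so subsequent calls are traced.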
#[must_use]
pub fn new(chat: Chat<'c, C>) -> Self {
Self { inner: chat }
}
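    /// Creates a chat completion and records it as a Langfuse generation.
    ///
    /// The request is captured as the generation input before the call; on
    /// success the model, usage, output, and any tool calls are recorded, and
    /// on error the generation is ended with an error level and the error
    /// message as its status.
    ///
    /// # Errors
    ///
    /// Returns any [`OpenAIError`] produced by the underlying client.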
pub async fn create(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
let generation = LangfuseGeneration::start("chat-completion");
generation.set_input(&request);
match self.inner.create(request).await {
Ok(response) => {
generation.set_model(&parser::extract_model(&response));
if let Some(usage) = parser::extract_usage(&response) {
generation.set_usage(&usage);
}
generation.set_output(&parser::extract_output(&response));
if let Some(tool_calls) = parser::extract_tool_calls(&response) {
generation.set_tool_calls(&tool_calls);
}
generation.end();
Ok(response)
}
Err(err) => {
generation.set_level(langfuse_core::types::SpanLevel::Error);
generation.set_status_message(&err.to_string());
generation.end();
Err(err)
}
}
}
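    /// Creates a streaming chat completion wrapped in a [`TracedStream`].
    ///
    /// The generation is ended when the stream is exhausted or dropped, so
    /// poll the stream to completion for the full output and usage to be
    /// recorded. A minimal consumption sketch (assumes a `chat` handle and a
    /// prepared `request`; not compiled as a doctest):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = chat.create_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?;
    ///     // forward or render the chunk; tracing happens transparently
    /// }
    /// // the generation is flushed here, once the stream is exhausted
    /// ```
    ///
    /// # Errors
    ///
    /// Returns any [`OpenAIError`] produced when opening the stream; in that
    /// case the generation is ended with an error status.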
    pub async fn create_stream(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<TracedStream, OpenAIError> {
        let generation = LangfuseGeneration::start("chat-completion");
        generation.set_input(&request);
        // Record the failure on the generation if the stream cannot be opened.
        match self.inner.create_stream(request).await {
            Ok(stream) => Ok(TracedStream::new(stream, generation)),
            Err(err) => {
                generation.set_level(langfuse_core::types::SpanLevel::Error);
                generation.set_status_message(&err.to_string());
                generation.end();
                Err(err)
            }
        }
    }
}
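/// Convenience constructor: wraps a client's chat API in a [`TracedChat`].
///
/// A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment and
/// `request` was built elsewhere (not compiled as a doctest):
///
/// ```ignore
/// let client = async_openai::Client::new();
/// let chat = observe_openai(&client);
/// let response = chat.create(request).await?;
/// ```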
pub fn observe_openai<C: Config>(client: &async_openai::Client<C>) -> TracedChat<'_, C> {
TracedChat::new(client.chat())
}
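/// A stream adapter that forwards chunks from the underlying completion stream
/// while accumulating content, tool calls, model, and usage for the Langfuse
/// generation. The generation is ended exactly once, when the stream finishes,
/// errors, or is dropped.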
pub struct TracedStream {
inner:
Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>,
generation: Option<LangfuseGeneration>,
accumulated_content: String,
model: Option<String>,
first_chunk: bool,
tool_call_acc: ToolCallAccumulator,
}
impl std::fmt::Debug for TracedStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TracedStream")
.field("first_chunk", &self.first_chunk)
.field("model", &self.model)
.field("accumulated_content", &self.accumulated_content)
.finish_non_exhaustive()
}
}
impl TracedStream {
fn new(
inner: Pin<
Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>,
>,
generation: LangfuseGeneration,
) -> Self {
Self {
inner,
generation: Some(generation),
accumulated_content: String::new(),
model: None,
first_chunk: true,
tool_call_acc: ToolCallAccumulator::new(),
}
}
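    /// Ends the generation exactly once, attaching everything accumulated so
    /// far. `Option::take` makes this idempotent across stream completion,
    /// error, and drop.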
fn finalize(&mut self) {
if let Some(generation) = self.generation.take() {
if let Some(model) = &self.model {
generation.set_model(model);
}
if !self.accumulated_content.is_empty() {
generation.set_output(&self.accumulated_content);
}
if self.tool_call_acc.has_calls() {
generation.set_tool_calls(&self.tool_call_acc.finalize());
}
generation.end();
}
}
}
impl Stream for TracedStream {
type Item = Result<CreateChatCompletionStreamResponse, OpenAIError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match this.inner.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
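                // The first chunk approximates time-to-first-token; record it
                // as the completion start time.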
if this.first_chunk {
this.first_chunk = false;
if let Some(span) = this.generation.as_ref() {
span.set_completion_start_time(&chrono::Utc::now());
}
}
if this.model.is_none() {
this.model = Some(chunk.model.clone());
}
if let Some(content) = parser::extract_stream_chunk_content(&chunk) {
this.accumulated_content.push_str(&content);
}
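                // The OpenAI API only reports usage on the final chunk, and
                // only when the request sets `stream_options.include_usage`.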
if let Some(usage) = parser::extract_stream_usage(&chunk)
&& let Some(span) = this.generation.as_ref()
{
span.set_usage(&usage);
}
this.tool_call_acc.accumulate(&chunk);
Poll::Ready(Some(Ok(chunk)))
}
Poll::Ready(Some(Err(err))) => {
if let Some(span) = this.generation.as_ref() {
span.set_level(langfuse_core::types::SpanLevel::Error);
span.set_status_message(&err.to_string());
}
this.finalize();
Poll::Ready(Some(Err(err)))
}
Poll::Ready(None) => {
this.finalize();
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
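// If the consumer drops the stream before exhausting it (for example, after a
// `break`), flush whatever was accumulated so the generation is still recorded.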
impl Drop for TracedStream {
fn drop(&mut self) {
self.finalize();
}
}
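/// A wrapper around [`Embeddings`] that records each embedding request as a
/// Langfuse embedding span, capturing input, model, and token usage.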
pub struct TracedEmbeddings<'c, C: Config> {
inner: Embeddings<'c, C>,
}
impl<C: Config> std::fmt::Debug for TracedEmbeddings<'_, C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TracedEmbeddings").finish()
}
}
impl<'c, C: Config> TracedEmbeddings<'c, C> {
#[must_use]
pub fn new(embeddings: Embeddings<'c, C>) -> Self {
Self { inner: embeddings }
}
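    /// Creates embeddings and records the call as a Langfuse embedding span.
    ///
    /// # Errors
    ///
    /// Returns any [`OpenAIError`] produced by the underlying client; the span
    /// is ended with an error status in that case.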
pub async fn create(
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, OpenAIError> {
let embedding = LangfuseEmbedding::start("embedding");
embedding.set_input(&serde_json::json!(request.input));
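        // Record the requested model up front so it is present even if the
        // call fails; on success it is replaced by the resolved model below.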
embedding.set_model(&request.model);
match self.inner.create(request).await {
Ok(response) => {
embedding.set_model(&response.model);
embedding.set_usage(&UsageDetails {
input: Some(u64::from(response.usage.prompt_tokens)),
output: None,
total: Some(u64::from(response.usage.total_tokens)),
});
embedding.end();
Ok(response)
}
Err(err) => {
embedding.set_level(langfuse_core::types::SpanLevel::Error);
embedding.set_status_message(&err.to_string());
embedding.end();
Err(err)
}
}
}
}
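/// Convenience constructor: wraps a client's embeddings API in a
/// [`TracedEmbeddings`].
///
/// A minimal sketch, assuming `OPENAI_API_KEY` is set and `request` is a
/// `CreateEmbeddingRequest` built elsewhere (not compiled as a doctest):
///
/// ```ignore
/// let client = async_openai::Client::new();
/// let embeddings = observe_openai_embeddings(&client);
/// let response = embeddings.create(request).await?;
/// ```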
pub fn observe_openai_embeddings<C: Config>(
client: &async_openai::Client<C>,
) -> TracedEmbeddings<'_, C> {
TracedEmbeddings::new(client.embeddings())
}