reagent-rs 0.2.7

A Rust library for building AI agents with MCP & custom tools
Documentation
use futures::{pin_mut, StreamExt};
use serde::Serialize;
use tracing::{error, span, Level, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::{
    notifications::Token,
    services::llm::{
        message::Message,
        models::chat::{ChatResponse, ChatStreamChunk},
        InferenceClientError,
    },
    ChatRequest, InvocationError, InvocationRequest, NotificationHandler, ToolCall,
};

#[derive(Debug, Serialize)]
struct UsageDetails {
    prompt_tokens: i64,
    completion_tokens: i64,
    total_tokens: i64,
}

pub(super) async fn invoke_nonstreaming(
    invocation_request: InvocationRequest,
) -> Result<ChatResponse, InvocationError> {
    let InvocationRequest {
        strip_thinking,
        request,
        client,
        notification_channel,
    } = invocation_request;

    notification_channel
        .notify_prompt_request(request.clone())
        .await;

    let gen_span = set_telemetry_request_attributes(&request);
    let _guard = gen_span.enter();

    let raw = client.chat(request).await;

    let mut resp = match raw {
        Ok(resp) => resp,
        Err(e) => {
            notification_channel
                .notify_prompt_error(e.to_string())
                .await;
            extract_error_telemetry(&gen_span, e.to_string().as_str());
            return Err(e.into());
        }
    };

    extract_response_telemetry(&gen_span, &resp);

    notification_channel
        .notify_prompt_success(resp.clone())
        .await;

    if strip_thinking {
        if let Some(content) = resp.message.content.clone() {
            if let Some(after) = content.split("</think>").nth(1) {
                resp.message.content = Some(after.to_string());
            }
        }
    }

    Ok(resp)
}

pub(super) async fn invoke_streaming(
    invocation_request: InvocationRequest,
) -> Result<ChatResponse, InvocationError> {
    let InvocationRequest {
        strip_thinking,
        request,
        client,
        notification_channel,
    } = invocation_request;

    notification_channel
        .notify_prompt_request(request.clone())
        .await;

    let gen_span = set_telemetry_request_attributes(&request);
    let _guard = gen_span.enter();

    let stream = match client.chat_stream(request).await {
        Ok(s) => s,
        Err(e) => {
            notification_channel
                .notify_prompt_error(e.to_string())
                .await;
            extract_error_telemetry(&gen_span, e.to_string().as_str());
            return Err(e.into());
        }
    };

    pin_mut!(stream);

    let mut full_content = None;
    let mut latest_message: Option<Message> = None;
    let mut tool_calls: Option<Vec<ToolCall>> = None;
    let mut done_chunk: Option<ChatStreamChunk> = None;

    while let Some(chunk_res) = stream.next().await {
        let chunk = match chunk_res {
            Ok(c) => c,
            Err(e) => {
                notification_channel
                    .notify_prompt_error(e.to_string())
                    .await;
                return Err(e.into());
            }
        };

        if chunk.done {
            done_chunk = Some(chunk);
            break;
        }

        if let Some(msg) = &chunk.message {
            if let Some(calls) = &msg.tool_calls {
                match tool_calls.as_mut() {
                    Some(tool_call_vec) => tool_call_vec.extend(calls.clone()),
                    None => tool_calls = Some(calls.clone()),
                }
            }

            if let Some(tok) = &msg.content {
                notification_channel
                    .notify_token(Token {
                        tag: None,
                        value: tok.clone(),
                    })
                    .await;
                match full_content.as_mut() {
                    None => full_content = Some(tok.to_owned()),
                    Some(content) => content.push_str(tok),
                }
            }

            latest_message = Some(msg.clone());
        }
    }

    let Some(chunk) = done_chunk else {
        let error_message = "stream ended without a final `done` chunk";
        extract_error_telemetry(&gen_span, error_message);
        return Err(InferenceClientError::Api(error_message.into()).into());
    };

    let mut final_msg = latest_message.unwrap_or_else(|| Message::assistant(String::new()));
    final_msg.content = full_content;
    final_msg.tool_calls = tool_calls;

    if strip_thinking {
        if let Some(c) = &final_msg.content {
            if let Some(after) = c.split("</think>").nth(1) {
                final_msg.content = Some(after.to_string());
            }
        }
    }

    let mut response = ChatResponse {
        model: chunk.model,
        created_at: chunk.created_at,
        message: final_msg,
        done: chunk.done,
        done_reason: chunk.done_reason,
        total_duration: chunk.total_duration,
        load_duration: chunk.load_duration,
        prompt_eval_count: chunk.prompt_eval_count,
        prompt_eval_duration: chunk.prompt_eval_duration,
        eval_count: chunk.eval_count,
        eval_duration: chunk.eval_duration,
    };

    extract_response_telemetry(&gen_span, &response);

    notification_channel
        .notify_prompt_success(response.clone())
        .await;

    if strip_thinking {
        if let Some(content) = response.message.content.clone() {
            if let Some(after) = content.split("</think>").nth(1) {
                response.message.content = Some(after.to_string());
            }
        }
    }

    Ok(response)
}

fn extract_error_telemetry(gen_span: &Span, error_message: &str) {
    gen_span.set_attribute("otel.status_code", "ERROR");
    gen_span.set_attribute("error.message", error_message.to_string());
    gen_span.set_status(opentelemetry::trace::Status::Error {
        description: error_message.to_string().into(),
    });
    error!("stream ended without a final `done` chunk");
}

fn extract_response_telemetry(gen_span: &Span, response: &ChatResponse) {
    let prompt_tokens = response.prompt_eval_count.unwrap_or(0);
    let completion_tokens = response.eval_count.unwrap_or(0);
    let total_tokens = prompt_tokens + completion_tokens;

    let usage = UsageDetails {
        prompt_tokens: prompt_tokens as i64,
        completion_tokens: completion_tokens as i64,
        total_tokens: total_tokens as i64,
    };

    gen_span.set_attribute(
        "langfuse.observation.output",
        serde_json::to_string_pretty(&response.message.content)
            .unwrap_or(format!("{:#?}", response)),
    );
    gen_span.set_attribute("gen_ai.completion.0.role", "assistant");
    gen_span.set_attribute(
        "gen_ai.completion.0.content",
        response.message.content.clone().unwrap_or_default(),
    );

    gen_span.set_attribute(
        "langfuse.observation.usage_details",
        serde_json::to_string(&usage).unwrap_or(format!("{:#?}", usage)),
    );

    if let Some(duration) = response.total_duration {
        gen_span.set_attribute("llm.duration.total_s", (duration as f64) / 1_000_000_000.0);
    }
    if let Some(duration) = response.load_duration {
        gen_span.set_attribute("llm.duration.load_s", (duration as f64) / 1_000_000_000.0);
    }
    if let Some(duration) = response.prompt_eval_duration {
        gen_span.set_attribute(
            "llm.duration.prompt_eval_ms",
            (duration as f64) / 1_000_000.0,
        );
    }
    if let Some(duration) = response.eval_duration {
        gen_span.set_attribute("llm.duration.eval_ms", (duration as f64) / 1_000_000.0);
    }

    gen_span.set_attribute("otel.status_code", "OK");
}

fn set_telemetry_request_attributes(request: &ChatRequest) -> Span {
    let gen_span = span!(
        Level::INFO,
        "Chat Request",
        "langfuse.observation.type" = "generation",
        "langfuse.observation.model.name" = request.base.model.as_str(),
        "langfuse.observation.input" =
            serde_json::to_string(&request.messages).unwrap_or(format!("{:#?}", request.messages)),
        "langfuse.observation.model.parameters" = serde_json::to_string(&request.base.options)
            .unwrap_or(format!("{:#?}", request.base.options)),
    );

    gen_span
}