chronicle_proxy/response.rs

use error_stack::{Report, ResultExt};
use serde::Serialize;
use serde_json::json;
use smallvec::smallvec;
use tracing::Span;

use crate::{
    database::logging::{CollectedProxiedResult, LogSender, ProxyLogEntry, ProxyLogEvent},
    format::{
        RequestInfo, ResponseInfo, SingleChatResponse, StreamingResponse,
        StreamingResponseReceiver, StreamingResponseSender,
    },
    request::TryModelChoicesResult,
    Error,
};

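/// Drive a provider response stream to completion: collect the chunks while forwarding them to
/// the output channel, record telemetry on the span, and send a log entry if logging is enabled.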
pub async fn handle_response(
    current_span: Span,
    log_entry: ProxyLogEvent,
    global_start: tokio::time::Instant,
    request_n: usize,
    meta: TryModelChoicesResult,
    chunk_rx: StreamingResponseReceiver,
    output_tx: StreamingResponseSender,
    log_tx: Option<LogSender>,
) {
    let response = collect_stream(
        current_span.clone(),
        log_entry,
        global_start,
        request_n,
        &meta,
        chunk_rx,
        output_tx,
        log_tx.as_ref(),
    )
    .await;
    let Ok((response, info, mut log_entry)) = response else {
        // Errors were already handled by collect_stream.
        return;
    };
    let global_send_time = global_start.elapsed();
    let this_send_time = meta.start_time.elapsed();
    log_entry.latency = Some(this_send_time);

    // With retries, the total latency can be meaningfully longer than the latency of the final attempt.
    current_span.record("llm.total_latency", global_send_time.as_millis());

    current_span.record(
        "llm.completions",
        response
            .choices
            .iter()
            .filter_map(|c| c.message.content.as_deref())
            .collect::<Vec<_>>()
            .join("\n\n"),
    );
    current_span.record(
        "llm.completions.raw",
        serde_json::to_string(&response.choices).ok(),
    );
    current_span.record("llm.vendor", &meta.provider);
    current_span.record("llm.response.model", &response.model);
    current_span.record("llm.latency", this_send_time.as_millis());
    current_span.record("llm.retries", meta.num_retries);
    current_span.record("llm.rate_limited", meta.was_rate_limited);

    let usage = response.usage.clone().unwrap_or_default();

    current_span.record("llm.usage.prompt_tokens", usage.prompt_tokens);
    current_span.record(
        "llm.finish_reason",
        response.choices.first().map(|c| c.finish_reason.as_str()),
    );
    current_span.record("llm.usage.completion_tokens", usage.completion_tokens);
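    // If the provider omitted `total_tokens`, fall back to summing the prompt and completion counts.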
    let total_tokens = usage
        .total_tokens
        .unwrap_or_else(|| usage.prompt_tokens.unwrap_or(0) + usage.completion_tokens.unwrap_or(0));
    current_span.record("llm.usage.total_tokens", total_tokens);

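    // Attach the collected response and retry metadata to the log entry and send it to the logger.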
    if let Some(log_tx) = log_tx {
        log_entry.total_latency = Some(global_send_time);
        log_entry.num_retries = Some(meta.num_retries);
        log_entry.was_rate_limited = Some(meta.was_rate_limited);
        log_entry.response = Some(CollectedProxiedResult {
            body: response,
            info,
            provider: meta.provider,
        });

        log_tx
            .send_async(smallvec![ProxyLogEntry::Proxied(Box::new(log_entry))])
            .await
            .ok();
    }
}

/// Internal stream collection that saves the information for logging.
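///
/// Returns `Err(())` if the stream yielded an error; in that case the error has already been
/// logged and forwarded to the output channel, so the caller can simply return.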
async fn collect_stream(
    current_span: Span,
    log_entry: ProxyLogEvent,
    global_start: tokio::time::Instant,
    request_n: usize,
    meta: &TryModelChoicesResult,
    chunk_rx: StreamingResponseReceiver,
    output_tx: StreamingResponseSender,
    log_tx: Option<&LogSender>,
) -> Result<(SingleChatResponse, ResponseInfo, ProxyLogEvent), ()> {
    let mut response = SingleChatResponse::new_for_collection(request_n);

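    // Default stats, used if the stream never sends a `ResponseInfo`.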
    let mut res_stats = ResponseInfo {
        model: String::new(),
        meta: None,
    };

    // Collect the message chunks so we can log the result, while also passing them on to the output channel.
    while let Ok(chunk) = chunk_rx.recv_async().await {
        tracing::info!(?chunk, "Got chunk");
        match &chunk {
            Ok(StreamingResponse::Chunk(chunk)) => {
                response.merge_delta(chunk);
            }
            Ok(StreamingResponse::ResponseInfo(i)) => {
                res_stats = i.clone();
            }
            Ok(StreamingResponse::RequestInfo(_)) => {
                // Don't need to handle RequestInfo since we've already incorporated its
                // information into `log_entry`.
            }
            Ok(StreamingResponse::Single(res)) => {
                response = res.clone();
            }
            Err(e) => {
                record_error(
                    log_entry,
                    e,
                    global_start,
                    meta.num_retries,
                    meta.was_rate_limited,
                    current_span,
                    log_tx,
                )
                .await;
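                // Forward the error to the consumer before bailing out; it has already been logged above.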
                output_tx.send_async(chunk).await.ok();
                return Err(());
            }
        }

        tracing::debug!(?chunk, "Sending chunk");
        output_tx.send_async(chunk).await.ok();
    }

    Ok((response, res_stats, log_entry))
}

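/// Record a failed request: log the error, annotate the current span, and, if logging is
/// enabled, send a log entry describing the failure.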
pub async fn record_error<E: std::fmt::Debug + std::fmt::Display>(
    mut log_entry: ProxyLogEvent,
    error: E,
    send_start: tokio::time::Instant,
    num_retries: u32,
    was_rate_limited: bool,
    current_span: Span,
    log_tx: Option<&LogSender>,
) {
    tracing::error!(error.full=?error, "Request failed");

    current_span.record("error", error.to_string());
    current_span.record("llm.retries", num_retries);
    current_span.record("llm.rate_limited", was_rate_limited);

    if let Some(log_tx) = log_tx {
        log_entry.total_latency = Some(send_start.elapsed());
        log_entry.num_retries = Some(num_retries);
        log_entry.was_rate_limited = Some(was_rate_limited);
        log_entry.error = Some(json!(format!("{:?}", error)));
        log_tx
            .send_async(smallvec![ProxyLogEntry::Proxied(Box::new(log_entry))])
            .await
            .ok();
    }
}

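/// A complete response assembled from a stream, along with the request and response metadata
/// seen while collecting it.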
#[derive(Serialize, Debug)]
pub struct CollectedResponse {
    pub request_info: RequestInfo,
    pub response_info: ResponseInfo,
    pub was_streaming: bool,
    pub num_chunks: usize,
    pub response: SingleChatResponse,
}

/// Collect a stream's contents into a single response.
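///
/// Streaming chunks are merged together with `merge_delta`, and any error received on the
/// stream is returned as an `Error::ModelError`.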
pub async fn collect_response(
    receiver: StreamingResponseReceiver,
    request_n: usize,
) -> Result<CollectedResponse, Report<Error>> {
    let mut request_info = None;
    let mut response_info = None;
    let mut was_streaming = false;

    let mut num_chunks = 0;
    let mut response = SingleChatResponse::new_for_collection(request_n);

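    // Drain the stream, merging streaming chunks and capturing the request/response metadata.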
    while let Ok(res) = receiver.recv_async().await {
        tracing::debug!(?res, "Got response chunk");
        match res.change_context(Error::ModelError)? {
            StreamingResponse::RequestInfo(info) => {
                debug_assert!(request_info.is_none(), "Saw multiple RequestInfo objects");
                debug_assert_eq!(num_chunks, 0, "RequestInfo was not the first chunk");
                request_info = Some(info);
            }
            StreamingResponse::ResponseInfo(info) => {
                debug_assert!(response_info.is_none(), "Saw multiple ResponseInfo objects");
                response_info = Some(info);
            }
            StreamingResponse::Single(res) => {
                debug_assert_eq!(num_chunks, 0, "Saw more than one non-streaming chunk");
                num_chunks += 1;
                response = res;
            }
            StreamingResponse::Chunk(res) => {
                was_streaming = true;
                num_chunks += 1;
                response.merge_delta(&res);
            }
        }
    }

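    // A `RequestInfo` chunk should always be present; treat its absence as a malformed stream.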
    let request_info = request_info.ok_or(Error::MissingStreamInformation("request info"))?;
    Ok(CollectedResponse {
        response_info: response_info.unwrap_or_else(|| ResponseInfo {
            meta: None,
            model: request_info.model.clone(),
        }),
        request_info,
        was_streaming,
        num_chunks,
        response,
    })
}