use error_stack::{Report, ResultExt};
use serde::Serialize;
use serde_json::json;
use smallvec::smallvec;
use tracing::Span;

use crate::{
    database::logging::{CollectedProxiedResult, LogSender, ProxyLogEntry, ProxyLogEvent},
    format::{
        RequestInfo, ResponseInfo, SingleChatResponse, StreamingResponse,
        StreamingResponseReceiver, StreamingResponseSender,
    },
    request::TryModelChoicesResult,
    Error,
};

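/// Collect the provider's streaming response while forwarding every chunk to `output_tx`,
/// record telemetry fields on `current_span`, and submit the finished request to `log_tx`
/// when a logger is configured.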
pub async fn handle_response(
    current_span: Span,
    log_entry: ProxyLogEvent,
    global_start: tokio::time::Instant,
    request_n: usize,
    meta: TryModelChoicesResult,
    chunk_rx: StreamingResponseReceiver,
    output_tx: StreamingResponseSender,
    log_tx: Option<LogSender>,
) {
    let response = collect_stream(
        current_span.clone(),
        log_entry,
        global_start,
        request_n,
        &meta,
        chunk_rx,
        output_tx,
        log_tx.as_ref(),
    )
    .await;
    let Ok((response, info, mut log_entry)) = response else {
        return;
    };
    let global_send_time = global_start.elapsed();
    let this_send_time = meta.start_time.elapsed();
    log_entry.latency = Some(this_send_time);

    current_span.record("llm.total_latency", global_send_time.as_millis());

    current_span.record(
        "llm.completions",
        response
            .choices
            .iter()
            .filter_map(|c| c.message.content.as_deref())
            .collect::<Vec<_>>()
            .join("\n\n"),
    );
    current_span.record(
        "llm.completions.raw",
        serde_json::to_string(&response.choices).ok(),
    );
    current_span.record("llm.vendor", &meta.provider);
    current_span.record("llm.response.model", &response.model);
    current_span.record("llm.latency", this_send_time.as_millis());
    current_span.record("llm.retries", meta.num_retries);
    current_span.record("llm.rate_limited", meta.was_rate_limited);

    let usage = response.usage.clone().unwrap_or_default();

    current_span.record("llm.usage.prompt_tokens", usage.prompt_tokens);
    current_span.record(
        "llm.finish_reason",
        response.choices.get(0).map(|c| c.finish_reason.as_str()),
    );
    current_span.record("llm.usage.completion_tokens", usage.completion_tokens);
    let total_tokens = usage
        .total_tokens
        .unwrap_or_else(|| usage.prompt_tokens.unwrap_or(0) + usage.completion_tokens.unwrap_or(0));
    current_span.record("llm.usage.total_tokens", total_tokens);

    if let Some(log_tx) = log_tx {
        log_entry.total_latency = Some(global_send_time);
        log_entry.num_retries = Some(meta.num_retries);
        log_entry.was_rate_limited = Some(meta.was_rate_limited);
        log_entry.response = Some(CollectedProxiedResult {
            body: response,
            info,
            provider: meta.provider,
        });

        log_tx
            .send_async(smallvec![ProxyLogEntry::Proxied(Box::new(log_entry))])
            .await
            .ok();
    }
}

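/// Drain `chunk_rx`, merging streamed deltas into a single response while passing each chunk
/// through to `output_tx`. Returns the collected response, its `ResponseInfo`, and the log
/// entry, or `Err(())` if the stream yielded an error, which is recorded before returning.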
async fn collect_stream(
    current_span: Span,
    log_entry: ProxyLogEvent,
    global_start: tokio::time::Instant,
    request_n: usize,
    meta: &TryModelChoicesResult,
    chunk_rx: StreamingResponseReceiver,
    output_tx: StreamingResponseSender,
    log_tx: Option<&LogSender>,
) -> Result<(SingleChatResponse, ResponseInfo, ProxyLogEvent), ()> {
    let mut response = SingleChatResponse::new_for_collection(request_n);

    let mut res_stats = ResponseInfo {
        model: String::new(),
        meta: None,
    };

    while let Ok(chunk) = chunk_rx.recv_async().await {
        tracing::info!(?chunk, "Got chunk");
        match &chunk {
            Ok(StreamingResponse::Chunk(chunk)) => {
                response.merge_delta(chunk);
            }
            Ok(StreamingResponse::ResponseInfo(i)) => {
                res_stats = i.clone();
            }
            Ok(StreamingResponse::RequestInfo(_)) => {
                // Nothing to collect here; the chunk is still forwarded to `output_tx` below.
            }
            Ok(StreamingResponse::Single(res)) => {
                response = res.clone();
            }
            Err(e) => {
                // Record and log the failure, forward the error to the consumer, and stop.
                record_error(
                    log_entry,
                    e,
                    global_start,
                    meta.num_retries,
                    meta.was_rate_limited,
                    current_span,
                    log_tx,
                )
                .await;
                output_tx.send_async(chunk).await.ok();
                return Err(());
            }
        }

        tracing::debug!(?chunk, "Sending chunk");
        output_tx.send_async(chunk).await.ok();
    }

    Ok((response, res_stats, log_entry))
}

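/// Record a failed request: log the error, attach error and retry information to
/// `current_span`, and send an error entry to the logger when one is configured.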
pub async fn record_error<E: std::fmt::Debug + std::fmt::Display>(
    mut log_entry: ProxyLogEvent,
    error: E,
    send_start: tokio::time::Instant,
    num_retries: u32,
    was_rate_limited: bool,
    current_span: Span,
    log_tx: Option<&LogSender>,
) {
    tracing::error!(error.full=?error, "Request failed");

    current_span.record("error", error.to_string());
    current_span.record("llm.retries", num_retries);
    current_span.record("llm.rate_limited", was_rate_limited);

    if let Some(log_tx) = log_tx {
        log_entry.total_latency = Some(send_start.elapsed());
        log_entry.num_retries = Some(num_retries);
        log_entry.was_rate_limited = Some(was_rate_limited);
        log_entry.error = Some(json!(format!("{:?}", error)));
        log_tx
            .send_async(smallvec![ProxyLogEntry::Proxied(Box::new(log_entry))])
            .await
            .ok();
    }
}

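/// A complete response assembled from a stream of [`StreamingResponse`] events.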
#[derive(Serialize, Debug)]
pub struct CollectedResponse {
    pub request_info: RequestInfo,
    pub response_info: ResponseInfo,
    pub was_streaming: bool,
    pub num_chunks: usize,
    pub response: SingleChatResponse,
}

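/// Collect every event from `receiver` into a single [`CollectedResponse`], merging streaming
/// chunks and capturing the request and response metadata. Fails if the stream yields an error
/// or never provides its `RequestInfo`.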
pub async fn collect_response(
    receiver: StreamingResponseReceiver,
    request_n: usize,
) -> Result<CollectedResponse, Report<Error>> {
    let mut request_info = None;
    let mut response_info = None;
    let mut was_streaming = false;

    let mut num_chunks = 0;
    let mut response = SingleChatResponse::new_for_collection(request_n);

    while let Ok(res) = receiver.recv_async().await {
        tracing::debug!(?res, "Got response chunk");
        match res.change_context(Error::ModelError)? {
            StreamingResponse::RequestInfo(info) => {
                debug_assert!(request_info.is_none(), "Saw multiple RequestInfo objects");
                debug_assert_eq!(num_chunks, 0, "RequestInfo was not the first chunk");
                request_info = Some(info);
            }
            StreamingResponse::ResponseInfo(info) => {
                debug_assert!(response_info.is_none(), "Saw multiple ResponseInfo objects");
                response_info = Some(info);
            }
            StreamingResponse::Single(res) => {
                debug_assert_eq!(num_chunks, 0, "Saw more than one non-streaming chunk");
                num_chunks += 1;
                response = res;
            }
            StreamingResponse::Chunk(res) => {
                was_streaming = true;
                num_chunks += 1;
                response.merge_delta(&res);
            }
        }
    }

    let request_info = request_info.ok_or(Error::MissingStreamInformation("request info"))?;
    Ok(CollectedResponse {
        response_info: response_info.unwrap_or_else(|| ResponseInfo {
            meta: None,
            model: request_info.model.clone(),
        }),
        request_info,
        was_streaming,
        num_chunks,
        response,
    })
}