chronicle_proxy/lib.rs

//! Chronicle LLM proxy and observability tool.
//!
//! This crate implements the proxy itself, which can be embedded into a Rust application.
//! For other uses, you may want to try the full-fledged API application in the `chronicle-api` crate.

use std::{borrow::Cow, fmt::Debug, str::FromStr, sync::Arc, time::Duration};

pub mod builder;
pub mod config;
pub mod database;
pub mod error;
pub mod format;
mod provider_lookup;
pub mod providers;
pub mod request;
mod response;
mod streaming;
#[cfg(test)]
mod testing;
pub mod workflow_events;

use builder::ProxyBuilder;
use config::{AliasConfig, ApiKeyConfig};
use database::logging::{LogSender, ProxyLogEntry, ProxyLogEvent};
pub use error::Error;
use error_stack::{Report, ResultExt};
use format::{
    ChatRequest, RequestInfo, SingleChatResponse, StreamingResponse, StreamingResponseReceiver,
    StreamingResponseSender,
};
use http::HeaderMap;
use provider_lookup::{ModelLookupResult, ProviderLookup};
use providers::ChatModelProvider;
use request::RetryOptions;
pub use response::{collect_response, CollectedResponse};
use response::{handle_response, record_error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
use smallvec::{smallvec, SmallVec};
use tracing::{instrument, Span};
use uuid::Uuid;
use workflow_events::{EventPayload, WorkflowEvent};

use crate::request::try_model_choices;

pub type AnyChatModelProvider = Arc<dyn ChatModelProvider>;

#[derive(Debug, Serialize)]
pub struct ProxiedChatResponseMeta {
    /// A UUID assigned by Chronicle to the request, which is linked to the logged information.
    /// This is different from the `id` returned at the top level of the [ChatResponse], which
    /// comes from the provider itself.
    pub id: Uuid,
    /// Which provider was used for the request.
    pub provider: String,
    pub response_meta: Option<serde_json::Value>,
    pub was_rate_limited: bool,
}

#[derive(Debug, Serialize)]
pub struct ProxiedChatResponse {
    #[serde(flatten)]
    pub response: SingleChatResponse,
    pub meta: ProxiedChatResponseMeta,
}

/// The Chronicle proxy object
#[derive(Debug)]
pub struct Proxy {
    log_tx: Option<LogSender>,
    log_task: Option<tokio::task::JoinHandle<()>>,
    lookup: ProviderLookup,
    default_timeout: Option<Duration>,
}

impl Proxy {
    /// Create a builder for the proxy
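    ///
    /// # Example
    ///
    /// A minimal sketch, based on this crate's tests; see [builder::ProxyBuilder] for the
    /// full set of configuration options.
    ///
    /// ```ignore
    /// let mut proxy = Proxy::builder().build().await?;
    /// // ... send requests ...
    /// proxy.shutdown().await;
    /// ```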
    pub fn builder() -> ProxyBuilder {
        ProxyBuilder::new()
    }

    /// Record an event to the database. This lets you keep your LLM request events and other
    /// events in the same database table.
    pub async fn record_event(&self, body: EventPayload) -> Uuid {
        let id = Uuid::now_v7();

        let Some(log_tx) = &self.log_tx else {
            return id;
        };

        let log_entry = ProxyLogEntry::Proxied(Box::new(ProxyLogEvent::from_payload(id, body)));

        log_tx.send_async(smallvec![log_entry]).await.ok();

        id
    }

    /// Record a workflow event to the database
    pub async fn record_workflow_event(&self, event: WorkflowEvent) {
        let Some(log_tx) = &self.log_tx else {
            return;
        };

        log_tx
            .send_async(smallvec![ProxyLogEntry::Workflow(event)])
            .await
            .ok();
    }

    /// Record multiple events, steps, and run updates
    pub async fn record_event_batch(&self, events: impl Into<SmallVec<[WorkflowEvent; 1]>>) {
        let Some(log_tx) = &self.log_tx else {
            return;
        };

        let events = events
            .into()
            .into_iter()
            .map(ProxyLogEntry::Workflow)
            .collect::<_>();

        log_tx.send_async(events).await.ok();
    }

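    /// Send a chat request, returning a channel that yields the response. On success, the
    /// channel first receives a [StreamingResponse::RequestInfo] event, followed by the rest
    /// of the response. Use [collect_response] to gather the channel's contents into a single
    /// [CollectedResponse].
    ///
    /// # Example
    ///
    /// A sketch of a simple non-streaming call, adapted from this crate's tests. The model
    /// name here is only illustrative, and the `1` passed to `collect_response` is the
    /// expected number of choices (`n`).
    ///
    /// ```ignore
    /// let chan = proxy
    ///     .send(
    ///         ProxyRequestOptions::default(),
    ///         ChatRequest {
    ///             model: Some("gpt-4o-mini".to_string()),
    ///             messages: vec![ChatMessage {
    ///                 role: Some("user".to_string()),
    ///                 content: Some("hello".to_string()),
    ///                 ..Default::default()
    ///             }],
    ///             ..Default::default()
    ///         },
    ///     )
    ///     .await?;
    /// let response = collect_response(chan, 1).await?;
    /// ```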
    pub async fn send(
        &self,
        options: ProxyRequestOptions,
        body: ChatRequest,
    ) -> Result<StreamingResponseReceiver, Report<Error>> {
        let (chunk_tx, chunk_rx) = if body.stream {
            flume::unbounded()
        } else {
            flume::bounded(5)
        };

        let models = self.lookup.find_model_and_provider(&options, &body)?;

        if models.choices.is_empty() {
            return Err(Report::new(Error::AliasEmpty(models.alias)));
        }

        let parent_span = tracing::Span::current();
        let log_tx = self.log_tx.clone();
        let default_timeout = self.default_timeout;
        tokio::task::spawn(async move {
            Self::send_request(
                parent_span,
                options,
                models,
                body,
                default_timeout,
                chunk_tx,
                log_tx,
            )
            .await
        });
        Ok(chunk_rx)
    }

    /// Send a request, choosing the provider based on the requested `model` and `provider`.
    ///
    /// The model and provider are resolved in this order:
    /// - `options.models`, if set, supplies an explicit list of models and providers to try.
    /// - `options.model` is used next to choose a model; this and `body["model"]` may be an
    ///   alias name.
    /// - `options.provider` can be used to choose a specific provider if the model is not an
    ///   alias.
    /// - `body["model"]` is used if `options.model` is empty.
    #[instrument(
        name = "llm.send_request",
        parent=&parent_span,
        skip(options),
        fields(
            error,
            llm.options=serde_json::to_string(&options).ok(),
            llm.item_id,
            llm.finish_reason,
            llm.latency,
            llm.total_latency,
            llm.retries,
            llm.rate_limited,
            llm.status_code,
            llm.meta.application = options.metadata.application,
            llm.meta.environment = options.metadata.environment,
            llm.meta.organization_id = options.metadata.organization_id,
            llm.meta.project_id = options.metadata.project_id,
            llm.meta.user_id = options.metadata.user_id,
            llm.meta.workflow_id = options.metadata.workflow_id,
            llm.meta.workflow_name = options.metadata.workflow_name,
            llm.meta.run_id = options.metadata.run_id.map(|u| u.to_string()),
            llm.meta.step = options.metadata.step_id.map(|u| u.to_string()),
            llm.meta.step_index = options.metadata.step_index,
            llm.meta.prompt_id = options.metadata.prompt_id,
            llm.meta.prompt_version = options.metadata.prompt_version,
            llm.meta.extra,
            llm.meta.internal_organization_id = options.internal_metadata.organization_id,
            llm.meta.internal_project_id = options.internal_metadata.project_id,
            llm.meta.internal_user_id = options.internal_metadata.user_id,
            // The fields below use the OpenLLMetry field names
            llm.vendor,
            // This will work once https://github.com/tokio-rs/tracing/pull/2925 is merged
            // llm.request.type = "chat",
            llm.request.model = body.model,
            llm.prompts,
            llm.prompts.raw = serde_json::to_string(&body.messages).ok(),
            llm.request.max_tokens = body.max_tokens,
            llm.response.model,
            llm.usage.prompt_tokens,
            llm.usage.completion_tokens,
            llm.usage.total_tokens,
            llm.completions,
            llm.completions.raw,
            llm.temperature = body.temperature,
            llm.top_p = body.top_p,
            llm.frequency_penalty = body.frequency_penalty,
            llm.presence_penalty = body.presence_penalty,
            llm.chat.stop_sequences,
            llm.user = body.user,
        )
    )]
    async fn send_request(
        parent_span: Span,
        options: ProxyRequestOptions,
        models: ModelLookupResult,
        body: ChatRequest,
        default_timeout: Option<Duration>,
        output_tx: StreamingResponseSender,
        log_tx: Option<LogSender>,
    ) {
        let id = uuid::Uuid::now_v7();
        let current_span = tracing::Span::current();
        current_span.record("llm.item_id", id.to_string());
        if !body.stop.is_empty() {
            current_span.record(
                "llm.chat.stop_sequences",
                serde_json::to_string(&body.stop).ok(),
            );
        }

        if let Some(extra) = options.metadata.extra.as_ref().filter(|e| !e.is_empty()) {
            current_span.record("llm.meta.extra", &serde_json::to_string(extra).ok());
        }

        let messages_field = if body.messages.len() > 1 {
            Some(Cow::Owned(
                body.messages
                    .iter()
                    .filter_map(|m| {
                        let content = m.content.as_deref()?;
                        Some(format!(
                            "{}: {}",
                            m.name.as_deref().or(m.role.as_deref()).unwrap_or_default(),
                            content
                        ))
                    })
                    .collect::<Vec<_>>()
                    .join("\n\n"),
            ))
        } else {
            body.messages
                .first()
                .and_then(|m| m.content.as_deref().map(Cow::Borrowed))
        };
        current_span.record("llm.prompts", messages_field.as_deref());

        if models.choices.len() == 1 {
            // If there's just one provider, we can record this in advance so it's captured even
            // if the request fails.
            current_span.record("llm.vendor", models.choices[0].provider.name());
        }

        tracing::info!(?body, "Starting request");

        let retry = options.retry.clone().unwrap_or_default();

        let (chunk_tx, chunk_rx) = flume::bounded(5);

        let timestamp = chrono::Utc::now();
        let global_start = tokio::time::Instant::now();
        let response = try_model_choices(
            models,
            options.override_url.clone(),
            retry,
            options
                .timeout
                .or(default_timeout)
                .unwrap_or_else(|| Duration::from_millis(60_000)),
            body.clone(),
            chunk_tx,
        )
        .await;

        let n = body.n.unwrap_or(1) as usize;

        // Fill in what we can now; the rest will be filled in once the response is done.
        let log_entry = ProxyLogEvent {
            id,
            event_type: Cow::Borrowed("chronicle_llm_request"),
            timestamp,
            request: Some(body),
            response: None,
            total_latency: None,
            latency: None,
            num_retries: None,
            was_rate_limited: None,
            error: None,
            options,
        };

        match response {
            Ok(res) => {
                output_tx
                    .send_async(Ok(StreamingResponse::RequestInfo(RequestInfo {
                        id,
                        provider: res.provider.clone(),
                        model: res.model.clone(),
                        num_retries: res.num_retries,
                        was_rate_limited: res.was_rate_limited,
                    })))
                    .await
                    .ok();
                handle_response(
                    current_span,
                    log_entry,
                    global_start,
                    n,
                    res,
                    chunk_rx,
                    output_tx,
                    log_tx,
                )
                .await;
            }
            Err(e) => {
                record_error(
                    log_entry,
                    &e.error,
                    global_start,
                    e.num_retries,
                    e.was_rate_limited,
                    current_span,
                    log_tx.as_ref(),
                )
                .await;
                output_tx.send_async(Err(e.error)).await.ok();
            }
        }
    }

    /// Add a provider to the system. This will replace any existing provider with the same `name`.
    pub fn set_provider(&self, provider: Arc<dyn ChatModelProvider>) {
        self.lookup.set_provider(provider);
    }

    /// Remove a provider. Any aliases that reference this provider's name will stop working.
    pub fn remove_provider(&self, name: &str) {
        self.lookup.remove_provider(name);
    }

    /// Add an alias to the system. This will replace any existing alias with the same `name`.
    pub fn set_alias(&self, alias: AliasConfig) {
        self.lookup.set_alias(alias);
    }

    /// Remove an alias
    pub fn remove_alias(&self, name: &str) {
        self.lookup.remove_alias(name);
    }

    /// Add an API key to the system. This will replace any existing API key with the same `name`.
    pub fn set_api_key(&self, api_key: ApiKeyConfig) {
        self.lookup.set_api_key(api_key);
    }

    /// Remove an API key. Any aliases that reference this API key's name will stop working.
    pub fn remove_api_key(&self, name: &str) {
        self.lookup.remove_api_key(name);
    }

    /// Shut down the proxy, making sure to write any queued logging events
    pub async fn shutdown(&mut self) {
        let log_tx = self.log_tx.take();
        drop(log_tx);
        let log_task = self.log_task.take();
        if let Some(log_task) = log_task {
            log_task.await.ok();
        }
    }

    /// Validate the loaded configuration, and return a list of problems found.
    // TODO: this doesn't do anything yet
    fn validate(&self) -> Vec<String> {
        self.lookup.validate()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelAndProvider {
    pub model: String,
    pub provider: String,
    /// Supply an API key.
    pub api_key: Option<String>,
    /// Get the API key from a preconfigured key
    pub api_key_name: Option<String>,
}

#[serde_as]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct ProxyRequestOptions {
    /// Override the model from the request body or select an alias.
    /// This can also be set by passing the x-chronicle-model HTTP header.
    pub model: Option<String>,
    /// Choose a specific provider to use. This can also be set by passing the
    /// x-chronicle-provider HTTP header.
    pub provider: Option<String>,
    /// Force the provider to use a specific URL instead of its default. This can also be set
    /// by passing the x-chronicle-override-url HTTP header.
    pub override_url: Option<String>,
    /// An API key to pass to the provider. This can also be set by passing the
    /// x-chronicle-provider-api-key HTTP header.
    pub api_key: Option<String>,
    /// Supply multiple provider/model choices, which will be tried in order.
    /// If this is provided, the `model`, `provider`, and `api_key` fields are ignored.
    /// This field cannot reference model aliases.
    /// This can also be set by passing the x-chronicle-models HTTP header using JSON syntax.
    #[serde(default)]
    pub models: Vec<ModelAndProvider>,
    /// When using `models` to supply multiple choices, start at a random choice instead of the
    /// first one.
    /// This can also be set by passing the x-chronicle-random-choice HTTP header.
    pub random_choice: Option<bool>,
    /// Customize the proxy's timeout when waiting for the request.
    /// This can also be set by passing the x-chronicle-timeout HTTP header.
    #[serde_as(as = "Option<DurationMilliSeconds>")]
    pub timeout: Option<std::time::Duration>,
    /// Customize the retry behavior. This can also be set by passing the
    /// x-chronicle-retry HTTP header.
    pub retry: Option<RetryOptions>,

    /// Metadata to record for the request
    #[serde(default)]
    pub metadata: ProxyRequestMetadata,

    /// Internal user authentication metadata for the request. This can be useful if you have a
    /// different set of internal users and organizations than what gets recorded in `metadata`.
    #[serde(skip, default)]
    pub internal_metadata: ProxyRequestInternalMetadata,
}

impl ProxyRequestOptions {
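    /// Fill in any unset options from the `x-chronicle-*` HTTP headers on a request.
    /// Values already present in the options take precedence over the headers.
    ///
    /// # Example
    ///
    /// A sketch of applying headers to a default set of options; the header values here are
    /// only illustrative.
    ///
    /// ```ignore
    /// use http::{HeaderMap, HeaderValue};
    ///
    /// let mut headers = HeaderMap::new();
    /// headers.insert("x-chronicle-provider", HeaderValue::from_static("openai"));
    /// // Timeout is given in milliseconds.
    /// headers.insert("x-chronicle-timeout", HeaderValue::from_static("30000"));
    ///
    /// let mut options = ProxyRequestOptions::default();
    /// options.merge_request_headers(&headers)?;
    /// assert_eq!(options.provider.as_deref(), Some("openai"));
    /// ```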
    pub fn merge_request_headers(&mut self, headers: &HeaderMap) -> Result<(), Report<Error>> {
        get_header_str(&mut self.api_key, headers, "x-chronicle-provider-api-key");
        get_header_str(&mut self.provider, headers, "x-chronicle-provider");
        get_header_str(&mut self.model, headers, "x-chronicle-model");
        get_header_str(&mut self.override_url, headers, "x-chronicle-override-url");

        let models_header = headers
            .get("x-chronicle-models")
            .map(|s| serde_json::from_slice::<Vec<ModelAndProvider>>(s.as_bytes()))
            .transpose()
            .change_context_lazy(|| {
                Error::ReadingHeader(
                    "x-chronicle-models".to_string(),
                    "Array of ModelAndProvider",
                )
            })?;
        if let Some(models_header) = models_header {
            self.models = models_header;
        }

        get_header_t(
            &mut self.random_choice,
            headers,
            "x-chronicle-random-choice",
            "boolean",
        )?;
        get_header_json(&mut self.retry, headers, "x-chronicle-retry")?;

        let timeout = headers
            .get("x-chronicle-timeout")
            .and_then(|s| s.to_str().ok())
            .map(|s| s.parse::<u64>())
            .transpose()
            .change_context_lazy(|| {
                Error::ReadingHeader("x-chronicle-timeout".to_string(), "integer")
            })?
            .map(std::time::Duration::from_millis);
        if timeout.is_some() {
            self.timeout = timeout;
        }

        self.metadata.merge_request_headers(headers)?;

        Ok(())
    }

    /// Merge values from `other` for any fields that are not already set.
    pub fn merge_from(&mut self, other: &Self) {
        if self.model.is_none() {
            self.model = other.model.clone();
        }
        if self.provider.is_none() {
            self.provider = other.provider.clone();
        }
        if self.override_url.is_none() {
            self.override_url = other.override_url.clone();
        }
        if self.api_key.is_none() {
            self.api_key = other.api_key.clone();
        }
        if self.models.is_empty() {
            self.models = other.models.clone();
        }
        if self.random_choice.is_none() {
            self.random_choice = other.random_choice;
        }
        if self.timeout.is_none() {
            self.timeout = other.timeout;
        }
        if self.retry.is_none() {
            self.retry = other.retry.clone();
        }
        self.metadata.merge_from(&other.metadata);
        self.internal_metadata.merge_from(&other.internal_metadata);
    }
}

#[derive(Debug, Serialize, Deserialize, Default)]
/// Metadata about the internal source of this request. Mostly useful for multi-tenant
/// scenarios where one proxy server is handling requests from multiple unrelated applications.
pub struct ProxyRequestInternalMetadata {
    /// The internal organization that the request belongs to
    pub organization_id: Option<String>,
    /// The internal project that the request belongs to
    pub project_id: Option<String>,
    /// The internal user ID that the request belongs to
    pub user_id: Option<String>,
}

impl ProxyRequestInternalMetadata {
    /// Merge values from `other` for any fields that are not already set.
    pub fn merge_from(&mut self, other: &Self) {
        if self.organization_id.is_none() {
            self.organization_id = other.organization_id.clone();
        }
        if self.project_id.is_none() {
            self.project_id = other.project_id.clone();
        }
        if self.user_id.is_none() {
            self.user_id = other.user_id.clone();
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Default)]
/// Metadata about the request and how it fits into the system as a whole. All of these
/// fields are optional, and the `extra` field can be used to add anything else that's useful
/// for your use case.
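///
/// # Example
///
/// When deserializing from JSON, any unknown fields are collected into `extra`. A sketch,
/// adapted from this crate's tests (`my_custom_field` is only illustrative):
///
/// ```ignore
/// let meta: ProxyRequestMetadata = serde_json::from_value(serde_json::json!({
///     "application": "abc",
///     "my_custom_field": "value",
/// }))?;
/// assert_eq!(meta.application.as_deref(), Some("abc"));
/// assert!(meta.extra.unwrap().contains_key("my_custom_field"));
/// ```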
pub struct ProxyRequestMetadata {
    /// The application making this call. This can also be set by passing the
    /// x-chronicle-application HTTP header.
    pub application: Option<String>,
    /// The environment the application is running in. This can also be set by passing the
    /// x-chronicle-environment HTTP header.
    pub environment: Option<String>,
    /// The organization related to the request. This can also be set by passing the
    /// x-chronicle-organization-id HTTP header.
    pub organization_id: Option<String>,
    /// The project related to the request. This can also be set by passing the
    /// x-chronicle-project-id HTTP header.
    pub project_id: Option<String>,
    /// The id of the user that triggered the request. This can also be set by passing the
    /// x-chronicle-user-id HTTP header.
    pub user_id: Option<String>,
    /// The id of the workflow that this request belongs to. This can also be set by passing the
    /// x-chronicle-workflow-id HTTP header.
    pub workflow_id: Option<String>,
    /// A readable name of the workflow that this request belongs to. This can also be set by
    /// passing the x-chronicle-workflow-name HTTP header.
    pub workflow_name: Option<String>,
    /// The id of the specific run that this request belongs to. This can also be set by
    /// passing the x-chronicle-run-id HTTP header.
    pub run_id: Option<Uuid>,
    /// The id of the workflow step. This can also be set by passing the
    /// x-chronicle-step-id HTTP header.
    pub step_id: Option<Uuid>,
    /// The index of the step within the workflow. This can also be set by passing the
    /// x-chronicle-step-index HTTP header.
    pub step_index: Option<u32>,
    /// A unique ID for this prompt. This can also be set by passing the
    /// x-chronicle-prompt-id HTTP header.
    pub prompt_id: Option<String>,
    /// The version of this prompt. This can also be set by passing the
    /// x-chronicle-prompt-version HTTP header.
    pub prompt_version: Option<u32>,

    /// Any other metadata to include. When passing this in the request body, any unknown fields
    /// are collected here. This can also be set by passing a JSON object to the
    /// x-chronicle-extra-meta HTTP header.
    #[serde(flatten)]
    pub extra: Option<serde_json::Map<String, serde_json::Value>>,
}

impl ProxyRequestMetadata {
    /// Fill in any unset metadata fields from the `x-chronicle-*` HTTP headers on a request.
    pub fn merge_request_headers(&mut self, headers: &HeaderMap) -> Result<(), Report<Error>> {
        get_header_str(&mut self.application, headers, "x-chronicle-application");
        get_header_str(&mut self.environment, headers, "x-chronicle-environment");
        get_header_str(
            &mut self.organization_id,
            headers,
            "x-chronicle-organization-id",
        );
        get_header_str(&mut self.project_id, headers, "x-chronicle-project-id");
        get_header_str(&mut self.user_id, headers, "x-chronicle-user-id");
        get_header_str(&mut self.workflow_id, headers, "x-chronicle-workflow-id");
        get_header_str(
            &mut self.workflow_name,
            headers,
            "x-chronicle-workflow-name",
        );
        get_header_t(&mut self.run_id, headers, "x-chronicle-run-id", "UUID")?;
        get_header_t(&mut self.step_id, headers, "x-chronicle-step-id", "UUID")?;
        get_header_t(
            &mut self.step_index,
            headers,
            "x-chronicle-step-index",
            "integer",
        )?;
        get_header_str(&mut self.prompt_id, headers, "x-chronicle-prompt-id");
        get_header_t(
            &mut self.prompt_version,
            headers,
            "x-chronicle-prompt-version",
            "integer",
        )?;
        get_header_json(&mut self.extra, headers, "x-chronicle-extra-meta")?;
        Ok(())
    }

    /// Merge values from `other` for any fields that are not already set.
    pub fn merge_from(&mut self, other: &Self) {
        if self.application.is_none() {
            self.application = other.application.clone();
        }
        if self.environment.is_none() {
            self.environment = other.environment.clone();
        }
        if self.organization_id.is_none() {
            self.organization_id = other.organization_id.clone();
        }
        if self.project_id.is_none() {
            self.project_id = other.project_id.clone();
        }
        if self.user_id.is_none() {
            self.user_id = other.user_id.clone();
        }
        if self.workflow_id.is_none() {
            self.workflow_id = other.workflow_id.clone();
        }
        if self.workflow_name.is_none() {
            self.workflow_name = other.workflow_name.clone();
        }
        if self.run_id.is_none() {
            self.run_id = other.run_id;
        }
        if self.step_id.is_none() {
            self.step_id = other.step_id;
        }
        if self.step_index.is_none() {
            self.step_index = other.step_index;
        }
        if self.prompt_id.is_none() {
            self.prompt_id = other.prompt_id.clone();
        }
        if self.prompt_version.is_none() {
            self.prompt_version = other.prompt_version;
        }
        if self.extra.is_none() {
            self.extra = other.extra.clone();
        }
    }
}

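/// Set `body_value` from the named header's string value, if `body_value` is not already set
/// and the header is present and valid UTF-8.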
fn get_header_str(body_value: &mut Option<String>, headers: &HeaderMap, key: &str) {
    if body_value.is_some() {
        return;
    }

    let value = headers
        .get(key)
        .and_then(|s| s.to_str().ok())
        .map(|s| s.to_string());

    if value.is_some() {
        *body_value = value;
    }
}

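/// Set `body_value` by parsing the named header with [FromStr], if `body_value` is not already
/// set. Fails with [Error::ReadingHeader], naming `expected_format`, if the header doesn't parse.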
fn get_header_t<T>(
    body_value: &mut Option<T>,
    headers: &HeaderMap,
    key: &str,
    expected_format: &'static str,
) -> Result<(), Report<Error>>
where
    T: FromStr,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    if body_value.is_some() {
        return Ok(());
    }

    let value = headers
        .get(key)
        .and_then(|s| s.to_str().ok())
        .map(|s| s.parse::<T>())
        .transpose()
        .change_context_lazy(|| Error::ReadingHeader(key.to_string(), expected_format))?;

    if value.is_some() {
        *body_value = value;
    }

    Ok(())
}

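/// Set `body_value` by parsing the named header as JSON, if `body_value` is not already set.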
fn get_header_json<T: DeserializeOwned>(
    body_value: &mut Option<T>,
    headers: &HeaderMap,
    key: &str,
) -> Result<(), Report<Error>> {
    if body_value.is_some() {
        return Ok(());
    }

    let value = headers
        .get(key)
        .and_then(|s| s.to_str().ok())
        .map(|s| serde_json::from_str(s))
        .transpose()
        .change_context_lazy(|| Error::ReadingHeader(key.to_string(), "JSON value"))?;

    if value.is_some() {
        *body_value = value;
    }

    Ok(())
}

#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use serde_json::json;
    use uuid::Uuid;
    use wiremock::{
        matchers::{method, path},
        Mock, ResponseTemplate,
    };

    use crate::{
        collect_response,
        config::CustomProviderConfig,
        format::{
            ChatChoice, ChatChoiceDelta, ChatMessage, ChatRequest, ChatResponse,
            StreamingChatResponse, UsageResponse,
        },
        providers::custom::{OpenAiRequestFormatOptions, ProviderRequestFormat},
        ProxyRequestMetadata,
    };

    /// Make sure `extra` flattening works as expected
    #[test]
    fn deserialize_meta() {
        let step = Uuid::now_v7();
        let test_value = json!({
            "application": "abc",
            "another": "value",
            "step_id": step,
            "third": "fourth",
        });

        let value: ProxyRequestMetadata =
            serde_json::from_value(test_value).expect("deserializing");

        println!("{value:#?}");
        assert_eq!(value.application, Some("abc".to_string()));
        assert_eq!(value.step_id, Some(step));
        assert_eq!(
            value.extra.as_ref().unwrap().get("another").unwrap(),
            &json!("value")
        );
        assert_eq!(
            value.extra.as_ref().unwrap().get("third").unwrap(),
            &json!("fourth")
        );
    }

    #[tokio::test]
    async fn call_provider_nonstreaming() {
        let mock_server = wiremock::MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(ChatResponse {
                created: 1,
                model: None,
                system_fingerprint: None,
                usage: Some(UsageResponse {
                    prompt_tokens: Some(1),
                    completion_tokens: Some(1),
                    total_tokens: Some(2),
                }),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage {
                        role: Some("assistant".to_string()),
                        content: Some("hello".to_string()),
                        tool_calls: Vec::new(),
                        ..Default::default()
                    },
                    finish_reason: crate::format::FinishReason::Stop,
                }],
            }))
            .mount(&mock_server)
            .await;

        let url = format!("{}/v1/chat/completions", mock_server.uri());

        let proxy = super::Proxy::builder()
            .with_custom_provider(CustomProviderConfig {
                name: "test".to_string(),
                url,
                format: ProviderRequestFormat::OpenAi(OpenAiRequestFormatOptions {
                    transforms: crate::format::ChatRequestTransformation {
                        supports_message_name: false,
                        system_in_messages: true,
                        strip_model_prefix: Some("me/".into()),
                    },
                }),
                label: None,
                api_key: None,
                api_key_source: None,
                headers: BTreeMap::default(),
                prefix: Some("me/".to_string()),
            })
            .build()
            .await
            .expect("Building proxy");

        let chan = proxy
            .send(
                crate::ProxyRequestOptions::default(),
                ChatRequest {
                    model: Some("me/a-test-model".to_string()),
                    messages: vec![ChatMessage {
                        role: Some("user".to_string()),
                        content: Some("hello".to_string()),
                        tool_calls: Vec::new(),
                        ..Default::default()
                    }],
                    ..Default::default()
                },
            )
            .await
            .expect("should have succeeded");

        let mut response = collect_response(chan, 1).await.unwrap();

        // The ID will be different every time, so zero it for the snapshot.
        response.request_info.id = uuid::Uuid::nil();
        insta::assert_json_snapshot!(response);
    }

    #[tokio::test]
    async fn call_provider_streaming() {
        let response1 = StreamingChatResponse {
            created: 1,
            model: Some("a_model".to_string()),
            system_fingerprint: Some("abbadada".to_string()),
            usage: Some(UsageResponse {
                prompt_tokens: Some(1),
                completion_tokens: Some(1),
                total_tokens: Some(2),
            }),
            choices: vec![ChatChoiceDelta {
                index: 0,
                delta: ChatMessage {
                    role: Some("assistant".to_string()),
                    content: Some("hello".to_string()),
                    tool_calls: Vec::new(),
                    ..Default::default()
                },
                finish_reason: None,
            }],
        };

        let response2 = StreamingChatResponse {
            created: 2,
            model: None,
            system_fingerprint: None,
            usage: Some(UsageResponse {
                prompt_tokens: Some(1),
                completion_tokens: Some(1),
                total_tokens: Some(2),
            }),
            choices: vec![ChatChoiceDelta {
                index: 0,
                delta: ChatMessage {
                    role: None,
                    content: Some(" and hello again".to_string()),
                    tool_calls: Vec::new(),
                    ..Default::default()
                },
                finish_reason: Some(crate::format::FinishReason::Stop),
            }],
        };

        let response_data = format!(
            "data: {}\n\ndata: {}\n\ndata: [DONE]",
            serde_json::to_string(&response1).unwrap(),
            serde_json::to_string(&response2).unwrap(),
        );

        let mock_server = wiremock::MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(response_data, "text/event-stream"),
            )
            .mount(&mock_server)
            .await;

        let url = format!("{}/v1/chat/completions", mock_server.uri());

        let proxy = super::Proxy::builder()
            .with_custom_provider(CustomProviderConfig {
                name: "test".to_string(),
                url,
                format: ProviderRequestFormat::OpenAi(OpenAiRequestFormatOptions {
                    transforms: crate::format::ChatRequestTransformation {
                        supports_message_name: false,
                        system_in_messages: true,
                        strip_model_prefix: Some("me/".into()),
                    },
                }),
                label: None,
                api_key: None,
                api_key_source: None,
                headers: BTreeMap::default(),
                prefix: Some("me/".to_string()),
            })
            .build()
            .await
            .expect("Building proxy");

        let chan = proxy
            .send(
                crate::ProxyRequestOptions::default(),
                ChatRequest {
                    model: Some("me/a-test-model".to_string()),
                    messages: vec![ChatMessage {
                        role: Some("user".to_string()),
                        content: Some("hello".to_string()),
                        tool_calls: Vec::new(),
                        ..Default::default()
                    }],
                    stream: true,
                    ..Default::default()
                },
            )
            .await
            .expect("should have succeeded");

        let mut response = collect_response(chan, 1).await.unwrap();

        // The ID will be different every time, so zero it for the snapshot.
        response.request_info.id = uuid::Uuid::nil();
        insta::assert_json_snapshot!(response);
    }
}
983}