1use std::{borrow::Cow, fmt::Debug, str::FromStr, sync::Arc, time::Duration};
6
7pub mod builder;
8pub mod config;
9pub mod database;
10pub mod error;
11pub mod format;
12mod provider_lookup;
13pub mod providers;
14pub mod request;
15mod response;
16mod streaming;
17#[cfg(test)]
18mod testing;
19pub mod workflow_events;
20
21use builder::ProxyBuilder;
22use config::{AliasConfig, ApiKeyConfig};
23use database::logging::{LogSender, ProxyLogEntry, ProxyLogEvent};
24pub use error::Error;
25use error_stack::{Report, ResultExt};
26use format::{
27 ChatRequest, RequestInfo, SingleChatResponse, StreamingResponse, StreamingResponseReceiver,
28 StreamingResponseSender,
29};
30use http::HeaderMap;
31use provider_lookup::{ModelLookupResult, ProviderLookup};
32use providers::ChatModelProvider;
33use request::RetryOptions;
34pub use response::{collect_response, CollectedResponse};
35use response::{handle_response, record_error};
36use serde::{de::DeserializeOwned, Deserialize, Serialize};
37use serde_with::{serde_as, DurationMilliSeconds};
38use smallvec::{smallvec, SmallVec};
39use tracing::{instrument, Span};
40use uuid::Uuid;
41use workflow_events::{EventPayload, WorkflowEvent};
42
43use crate::request::try_model_choices;
44
/// A shared, dynamically-dispatched handle to any chat model provider.
pub type AnyChatModelProvider = Arc<dyn ChatModelProvider>;
46
/// Proxy-level metadata describing how a proxied chat request was served.
#[derive(Debug, Serialize)]
pub struct ProxiedChatResponseMeta {
    /// The unique ID assigned to this request by the proxy.
    pub id: Uuid,
    /// The name of the provider that served the request.
    pub provider: String,
    /// Provider-specific response metadata, if any was returned.
    pub response_meta: Option<serde_json::Value>,
    /// Whether the request encountered rate limiting while being served.
    pub was_rate_limited: bool,
}
58
/// A chat response returned by the proxy: the provider's response combined
/// with proxy-level metadata about how the request was handled.
#[derive(Debug, Serialize)]
pub struct ProxiedChatResponse {
    /// The chat response itself; flattened to the top level when serialized.
    #[serde(flatten)]
    pub response: SingleChatResponse,
    /// Metadata about how the proxy served the request.
    pub meta: ProxiedChatResponseMeta,
}
65
/// The LLM proxy: routes chat requests to configured providers and, when
/// logging is enabled, records request/response events in the background.
#[derive(Debug)]
pub struct Proxy {
    /// Channel to the background logging task; `None` when logging is disabled.
    log_tx: Option<LogSender>,
    /// Join handle for the background logging task, awaited on shutdown.
    log_task: Option<tokio::task::JoinHandle<()>>,
    /// Resolves models, providers, aliases, and API keys for requests.
    lookup: ProviderLookup,
    /// Timeout applied when a request does not specify its own.
    default_timeout: Option<Duration>,
}
74
impl Proxy {
    /// Create a new [ProxyBuilder] for configuring and constructing a [Proxy].
    pub fn builder() -> ProxyBuilder {
        ProxyBuilder::new()
    }

    /// Record an event payload, returning the generated event ID. The ID is
    /// generated and returned even when logging is disabled, in which case
    /// nothing is actually recorded.
    pub async fn record_event(&self, body: EventPayload) -> Uuid {
        // UUIDv7 is time-ordered, which keeps stored events roughly sorted.
        let id = Uuid::now_v7();

        let Some(log_tx) = &self.log_tx else {
            return id;
        };

        let log_entry = ProxyLogEntry::Proxied(Box::new(ProxyLogEvent::from_payload(id, body)));

        // Logging is best-effort; a closed channel is silently ignored.
        log_tx.send_async(smallvec![log_entry]).await.ok();

        id
    }

    /// Record a single workflow event, if logging is enabled.
    pub async fn record_workflow_event(&self, event: WorkflowEvent) {
        let Some(log_tx) = &self.log_tx else {
            return;
        };

        log_tx
            .send_async(smallvec![ProxyLogEntry::Workflow(event)])
            .await
            .ok();
    }

    /// Record a batch of workflow events with a single channel send, if
    /// logging is enabled.
    pub async fn record_event_batch(&self, events: impl Into<SmallVec<[WorkflowEvent; 1]>>) {
        let Some(log_tx) = &self.log_tx else {
            return;
        };

        let events = events
            .into()
            .into_iter()
            .map(ProxyLogEntry::Workflow)
            .collect::<_>();

        log_tx.send_async(events).await.ok();
    }

    /// Send a chat request, returning a channel that yields the (possibly
    /// streaming) response. The request itself runs in a background task so
    /// the caller can start consuming the channel immediately.
    ///
    /// # Errors
    /// Returns an error if no model/provider can be resolved for the request
    /// or the resolved alias has no choices. Failures that occur while
    /// performing the request are delivered through the returned channel.
    pub async fn send(
        &self,
        options: ProxyRequestOptions,
        body: ChatRequest,
    ) -> Result<StreamingResponseReceiver, Report<Error>> {
        // Streaming responses can produce many chunks, so use an unbounded
        // channel there; non-streaming responses only send a few items.
        let (chunk_tx, chunk_rx) = if body.stream {
            flume::unbounded()
        } else {
            flume::bounded(5)
        };

        let models = self.lookup.find_model_and_provider(&options, &body)?;

        if models.choices.is_empty() {
            return Err(Report::new(Error::AliasEmpty(models.alias)));
        }

        // Capture everything the background task needs before spawning.
        let parent_span = tracing::Span::current();
        let log_tx = self.log_tx.clone();
        let default_timeout = self.default_timeout;
        tokio::task::spawn(async move {
            Self::send_request(
                parent_span,
                options,
                models,
                body,
                default_timeout,
                chunk_tx,
                log_tx,
            )
            .await
        });
        Ok(chunk_rx)
    }

    /// Perform the provider request for [Proxy::send]: record tracing fields,
    /// try the model choices (with retries and timeout), forward response
    /// chunks to `output_tx`, and emit a log entry for the request.
    #[instrument(
        name = "llm.send_request",
        parent=&parent_span,
        skip(options),
        fields(
            error,
            llm.options=serde_json::to_string(&options).ok(),
            llm.item_id,
            llm.finish_reason,
            llm.latency,
            llm.total_latency,
            llm.retries,
            llm.rate_limited,
            llm.status_code,
            llm.meta.application = options.metadata.application,
            llm.meta.environment = options.metadata.environment,
            llm.meta.organization_id = options.metadata.organization_id,
            llm.meta.project_id = options.metadata.project_id,
            llm.meta.user_id = options.metadata.user_id,
            llm.meta.workflow_id = options.metadata.workflow_id,
            llm.meta.workflow_name = options.metadata.workflow_name,
            llm.meta.run_id = options.metadata.run_id.map(|u| u.to_string()),
            llm.meta.step = options.metadata.step_id.map(|u| u.to_string()),
            llm.meta.step_index = options.metadata.step_index,
            llm.meta.prompt_id = options.metadata.prompt_id,
            llm.meta.prompt_version = options.metadata.prompt_version,
            llm.meta.extra,
            llm.meta.internal_organization_id = options.internal_metadata.organization_id,
            llm.meta.internal_project_id = options.internal_metadata.project_id,
            llm.meta.internal_user_id = options.internal_metadata.user_id,
            llm.vendor,
            llm.request.model = body.model,
            llm.prompts,
            llm.prompts.raw = serde_json::to_string(&body.messages).ok(),
            llm.request.max_tokens = body.max_tokens,
            llm.response.model,
            llm.usage.prompt_tokens,
            llm.usage.completion_tokens,
            llm.usage.total_tokens,
            llm.completions,
            llm.completions.raw,
            llm.temperature = body.temperature,
            llm.top_p = body.top_p,
            llm.frequency_penalty = body.frequency_penalty,
            llm.presence_penalty = body.presence_penalty,
            llm.chat.stop_sequences,
            llm.user = body.user,
        )
    )]
    async fn send_request(
        parent_span: Span,
        options: ProxyRequestOptions,
        models: ModelLookupResult,
        body: ChatRequest,
        default_timeout: Option<Duration>,
        output_tx: StreamingResponseSender,
        log_tx: Option<LogSender>,
    ) {
        let id = uuid::Uuid::now_v7();
        let current_span = tracing::Span::current();
        current_span.record("llm.item_id", id.to_string());
        if !body.stop.is_empty() {
            current_span.record(
                "llm.chat.stop_sequences",
                serde_json::to_string(&body.stop).ok(),
            );
        }

        if let Some(extra) = options.metadata.extra.as_ref().filter(|e| !e.is_empty()) {
            current_span.record("llm.meta.extra", &serde_json::to_string(extra).ok());
        }

        // Render the prompt for the trace: multiple messages are joined as
        // "name-or-role: content" paragraphs; a single message's content is
        // borrowed directly to avoid an allocation.
        let messages_field = if body.messages.len() > 1 {
            Some(Cow::Owned(
                body.messages
                    .iter()
                    .filter_map(|m| {
                        let Some(content) = m.content.as_deref() else {
                            return None;
                        };

                        Some(format!(
                            "{}: {}",
                            m.name.as_deref().or(m.role.as_deref()).unwrap_or_default(),
                            content
                        ))
                    })
                    .collect::<Vec<_>>()
                    .join("\n\n"),
            ))
        } else {
            body.messages
                .get(0)
                .and_then(|m| m.content.as_deref().map(Cow::Borrowed))
        };
        current_span.record("llm.prompts", messages_field.as_deref());

        // With exactly one choice the vendor is known up front.
        if models.choices.len() == 1 {
            current_span.record("llm.vendor", models.choices[0].provider.name());
        }

        tracing::info!(?body, "Starting request");

        let retry = options.retry.clone().unwrap_or_default();

        let (chunk_tx, chunk_rx) = flume::bounded(5);

        let timestamp = chrono::Utc::now();
        let global_start = tokio::time::Instant::now();
        let response = try_model_choices(
            models,
            options.override_url.clone(),
            retry,
            // Per-request timeout, then the proxy default, then 60 seconds.
            options
                .timeout
                .or(default_timeout)
                .unwrap_or_else(|| Duration::from_millis(60_000)),
            body.clone(),
            chunk_tx,
        )
        .await;

        let n = body.n.unwrap_or(1) as usize;

        // Base log entry; latency/response/error fields are filled in later
        // by handle_response or record_error.
        let log_entry = ProxyLogEvent {
            id,
            event_type: Cow::Borrowed("chronicle_llm_request"),
            timestamp,
            request: Some(body),
            response: None,
            total_latency: None,
            latency: None,
            num_retries: None,
            was_rate_limited: None,
            error: None,
            options,
        };

        match response {
            Ok(res) => {
                // Tell the caller which provider/model is serving the request
                // before any response chunks arrive.
                output_tx
                    .send_async(Ok(StreamingResponse::RequestInfo(RequestInfo {
                        id,
                        provider: res.provider.clone(),
                        model: res.model.clone(),
                        num_retries: res.num_retries,
                        was_rate_limited: res.was_rate_limited,
                    })))
                    .await
                    .ok();
                handle_response(
                    current_span,
                    log_entry,
                    global_start,
                    n,
                    res,
                    chunk_rx,
                    output_tx,
                    log_tx,
                )
                .await;
            }
            Err(e) => {
                record_error(
                    log_entry,
                    &e.error,
                    global_start,
                    e.num_retries,
                    e.was_rate_limited,
                    current_span,
                    log_tx.as_ref(),
                )
                .await;
                output_tx.send_async(Err(e.error)).await.ok();
            }
        }
    }

    /// Register a provider with the lookup table.
    pub fn set_provider(&self, provider: Arc<dyn ChatModelProvider>) {
        self.lookup.set_provider(provider);
    }

    /// Remove the provider with the given name from the lookup table.
    pub fn remove_provider(&self, name: &str) {
        self.lookup.remove_provider(name);
    }

    /// Register an alias with the lookup table.
    pub fn set_alias(&self, alias: AliasConfig) {
        self.lookup.set_alias(alias);
    }

    /// Remove the alias with the given name from the lookup table.
    pub fn remove_alias(&self, name: &str) {
        self.lookup.remove_alias(name);
    }

    /// Register an API key with the lookup table.
    pub fn set_api_key(&self, api_key: ApiKeyConfig) {
        self.lookup.set_api_key(api_key);
    }

    /// Remove the API key with the given name from the lookup table.
    pub fn remove_api_key(&self, name: &str) {
        self.lookup.remove_api_key(name);
    }

    /// Shut down the proxy: drop the log channel so the logging task sees a
    /// disconnect, then wait for that task to finish.
    pub async fn shutdown(&mut self) {
        let log_tx = self.log_tx.take();
        drop(log_tx);
        let log_task = self.log_task.take();
        if let Some(log_task) = log_task {
            log_task.await.ok();
        }
    }

    /// Validate the current lookup configuration, returning any problems found.
    fn validate(&self) -> Vec<String> {
        self.lookup.validate()
    }
}
396
/// A model paired with the provider that should serve it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelAndProvider {
    /// The model to request.
    pub model: String,
    /// The name of the provider to send the request to.
    pub provider: String,
    /// An API key to use with this provider, if set.
    pub api_key: Option<String>,
    /// The name of a preconfigured API key to use — presumably resolved by
    /// the provider lookup; confirm against `provider_lookup`.
    pub api_key_name: Option<String>,
}
406
/// Options controlling how the proxy routes and records a single request.
#[serde_as]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct ProxyRequestOptions {
    /// A model to use for the request.
    pub model: Option<String>,
    /// The name of a specific provider to route the request to.
    pub provider: Option<String>,
    /// A URL to send the request to instead of the provider's usual endpoint.
    pub override_url: Option<String>,
    /// An API key to pass to the provider.
    pub api_key: Option<String>,
    /// A list of model/provider pairs to try for this request.
    #[serde(default)]
    pub models: Vec<ModelAndProvider>,
    /// Whether to choose among `models` at random — NOTE(review): assumed
    /// from the name; confirm against `provider_lookup`.
    pub random_choice: Option<bool>,
    /// Per-request timeout, serialized as an integer number of milliseconds.
    #[serde_as(as = "Option<DurationMilliSeconds>")]
    pub timeout: Option<std::time::Duration>,
    /// Retry behavior for failed requests.
    pub retry: Option<RetryOptions>,

    /// Caller-supplied metadata recorded alongside the request.
    #[serde(default)]
    pub metadata: ProxyRequestMetadata,

    /// Metadata set internally by the hosting application; never read from
    /// or written to the serialized request.
    #[serde(skip, default)]
    pub internal_metadata: ProxyRequestInternalMetadata,
}
449
450impl ProxyRequestOptions {
451 pub fn merge_request_headers(&mut self, headers: &HeaderMap) -> Result<(), Report<Error>> {
452 get_header_str(&mut self.api_key, headers, "x-chronicle-provider-api-key");
453 get_header_str(&mut self.provider, headers, "x-chronicle-provider");
454 get_header_str(&mut self.model, headers, "x-chronicle-model");
455 get_header_str(&mut self.override_url, headers, "x-chronicle-override-url");
456
457 let models_header = headers
458 .get("x-chronicle-models")
459 .map(|s| serde_json::from_slice::<Vec<ModelAndProvider>>(s.as_bytes()))
460 .transpose()
461 .change_context_lazy(|| {
462 Error::ReadingHeader(
463 "x-chronicle-models".to_string(),
464 "Array of ModelAndProvider",
465 )
466 })?;
467 if let Some(models_header) = models_header {
468 self.models = models_header;
469 }
470
471 get_header_t(
472 &mut self.random_choice,
473 headers,
474 "x-chronicle-random-choice",
475 "boolean",
476 )?;
477 get_header_json(&mut self.retry, headers, "x-chronicle-retry")?;
478
479 let timeout = headers
480 .get("x-chronicle-timeout")
481 .and_then(|s| s.to_str().ok())
482 .map(|s| s.parse::<u64>())
483 .transpose()
484 .change_context_lazy(|| {
485 Error::ReadingHeader("x-chronicle-timeout".to_string(), "integer")
486 })?
487 .map(|s| std::time::Duration::from_millis(s));
488 if timeout.is_some() {
489 self.timeout = timeout;
490 }
491
492 self.metadata.merge_request_headers(headers)?;
493
494 Ok(())
495 }
496
497 pub fn merge_from(&mut self, other: &Self) {
499 if self.model.is_none() {
500 self.model = other.model.clone();
501 }
502 if self.provider.is_none() {
503 self.provider = other.provider.clone();
504 }
505 if self.override_url.is_none() {
506 self.override_url = other.override_url.clone();
507 }
508 if self.api_key.is_none() {
509 self.api_key = other.api_key.clone();
510 }
511 if self.models.is_empty() {
512 self.models = other.models.clone();
513 }
514 if self.random_choice.is_none() {
515 self.random_choice = other.random_choice;
516 }
517 if self.timeout.is_none() {
518 self.timeout = other.timeout;
519 }
520 if self.retry.is_none() {
521 self.retry = other.retry.clone();
522 }
523 self.metadata.merge_from(&other.metadata);
524 self.internal_metadata.merge_from(&other.internal_metadata);
525 }
526}
527
/// Metadata set by the hosting application itself, as opposed to the
/// caller-supplied [ProxyRequestMetadata].
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ProxyRequestInternalMetadata {
    /// Internal organization identifier.
    pub organization_id: Option<String>,
    /// Internal project identifier.
    pub project_id: Option<String>,
    /// Internal user identifier.
    pub user_id: Option<String>,
}
539
540impl ProxyRequestInternalMetadata {
541 pub fn merge_from(&mut self, other: &Self) {
542 if self.organization_id.is_none() {
543 self.organization_id = other.organization_id.clone();
544 }
545 if self.project_id.is_none() {
546 self.project_id = other.project_id.clone();
547 }
548 if self.user_id.is_none() {
549 self.user_id = other.user_id.clone();
550 }
551 }
552}
553
/// Caller-supplied metadata recorded alongside a proxied request. Every
/// field is optional.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ProxyRequestMetadata {
    /// The application making the request.
    pub application: Option<String>,
    /// The environment the request came from (e.g. staging, production).
    pub environment: Option<String>,
    /// The caller's organization.
    pub organization_id: Option<String>,
    /// The caller's project.
    pub project_id: Option<String>,
    /// The user making the request.
    pub user_id: Option<String>,
    /// The workflow this request belongs to.
    pub workflow_id: Option<String>,
    /// A human-readable name for the workflow.
    pub workflow_name: Option<String>,
    /// The specific run of the workflow.
    pub run_id: Option<Uuid>,
    /// The workflow step that issued this request.
    pub step_id: Option<Uuid>,
    /// The index of that step within the workflow run.
    pub step_index: Option<u32>,
    /// The prompt used for this request.
    pub prompt_id: Option<String>,
    /// The version of that prompt.
    pub prompt_version: Option<u32>,

    /// Any other metadata. `#[serde(flatten)]` collects keys not matching
    /// the fields above into this map on deserialization.
    #[serde(flatten)]
    pub extra: Option<serde_json::Map<String, serde_json::Value>>,
}
602
603impl ProxyRequestMetadata {
604 pub fn merge_request_headers(&mut self, headers: &HeaderMap) -> Result<(), Report<Error>> {
605 get_header_str(&mut self.application, headers, "x-chronicle-application");
606 get_header_str(&mut self.environment, headers, "x-chronicle-environment");
607 get_header_str(
608 &mut self.organization_id,
609 headers,
610 "x-chronicle-organization-id",
611 );
612 get_header_str(&mut self.project_id, headers, "x-chronicle-project-id");
613 get_header_str(&mut self.user_id, headers, "x-chronicle-user-id");
614 get_header_str(&mut self.workflow_id, headers, "x-chronicle-workflow-id");
615 get_header_str(
616 &mut self.workflow_name,
617 headers,
618 "x-chronicle-workflow-name",
619 );
620 get_header_t(&mut self.run_id, headers, "x-chronicle-run-id", "UUID")?;
621 get_header_t(&mut self.step_id, headers, "x-chronicle-step-id", "UUID")?;
622 get_header_t(
623 &mut self.step_index,
624 headers,
625 "x-chronicle-step-index",
626 "integer",
627 )?;
628 get_header_str(&mut self.prompt_id, headers, "x-chronicle-prompt-id");
629 get_header_t(
630 &mut self.prompt_version,
631 headers,
632 "x-chronicle-prompt-version",
633 "integer",
634 )?;
635 get_header_json(&mut self.extra, headers, "x-chronicle-extra-meta")?;
636 Ok(())
637 }
638
639 pub fn merge_from(&mut self, other: &Self) {
641 if self.application.is_none() {
642 self.application = other.application.clone();
643 }
644 if self.environment.is_none() {
645 self.environment = other.environment.clone();
646 }
647 if self.organization_id.is_none() {
648 self.organization_id = other.organization_id.clone();
649 }
650 if self.project_id.is_none() {
651 self.project_id = other.project_id.clone();
652 }
653 if self.user_id.is_none() {
654 self.user_id = other.user_id.clone();
655 }
656 if self.workflow_id.is_none() {
657 self.workflow_id = other.workflow_id.clone();
658 }
659 if self.workflow_name.is_none() {
660 self.workflow_name = other.workflow_name.clone();
661 }
662 if self.run_id.is_none() {
663 self.run_id = other.run_id;
664 }
665 if self.step_id.is_none() {
666 self.step_id = other.step_id;
667 }
668 if self.step_index.is_none() {
669 self.step_index = other.step_index;
670 }
671 if self.prompt_id.is_none() {
672 self.prompt_id = other.prompt_id.clone();
673 }
674 if self.prompt_version.is_none() {
675 self.prompt_version = other.prompt_version;
676 }
677 if self.extra.is_none() {
678 self.extra = other.extra.clone();
679 }
680 }
681}
682
683fn get_header_str(body_value: &mut Option<String>, headers: &HeaderMap, key: &str) {
684 if body_value.is_some() {
685 return;
686 }
687
688 let value = headers
689 .get(key)
690 .and_then(|s| s.to_str().ok())
691 .map(|s| s.to_string());
692
693 if value.is_some() {
694 *body_value = value;
695 }
696}
697
698fn get_header_t<T>(
699 body_value: &mut Option<T>,
700 headers: &HeaderMap,
701 key: &str,
702 expected_format: &'static str,
703) -> Result<(), Report<Error>>
704where
705 T: FromStr,
706 T::Err: std::error::Error + Send + Sync + 'static,
707{
708 if body_value.is_some() {
709 return Ok(());
710 }
711
712 let value = headers
713 .get(key)
714 .and_then(|s| s.to_str().ok())
715 .map(|s| s.parse::<T>())
716 .transpose()
717 .change_context_lazy(|| Error::ReadingHeader(key.to_string(), expected_format))?;
718
719 if value.is_some() {
720 *body_value = value;
721 }
722
723 Ok(())
724}
725
726fn get_header_json<T: DeserializeOwned>(
727 body_value: &mut Option<T>,
728 headers: &HeaderMap,
729 key: &str,
730) -> Result<(), Report<Error>> {
731 if body_value.is_some() {
732 return Ok(());
733 }
734
735 let value = headers
736 .get(key)
737 .and_then(|s| s.to_str().ok())
738 .map(|s| serde_json::from_str(s))
739 .transpose()
740 .change_context_lazy(|| Error::ReadingHeader(key.to_string(), "JSON value"))?;
741
742 if value.is_some() {
743 *body_value = value;
744 }
745
746 Ok(())
747}
748
#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use serde_json::json;
    use uuid::Uuid;
    use wiremock::{
        matchers::{method, path},
        Mock, ResponseTemplate,
    };

    use crate::{
        collect_response,
        config::CustomProviderConfig,
        format::{
            ChatChoice, ChatChoiceDelta, ChatMessage, ChatRequest, ChatResponse,
            StreamingChatResponse, UsageResponse,
        },
        providers::custom::{OpenAiRequestFormatOptions, ProviderRequestFormat},
        ProxyRequestMetadata,
    };

    /// Known fields deserialize normally while unknown keys are collected
    /// into `extra` via `#[serde(flatten)]`.
    #[test]
    fn deserialize_meta() {
        let step = Uuid::now_v7();
        let test_value = json!({
            "application": "abc",
            "another": "value",
            "step_id": step,
            "third": "fourth",
        });

        let value: ProxyRequestMetadata =
            serde_json::from_value(test_value).expect("deserializing");

        println!("{value:#?}");
        assert_eq!(value.application, Some("abc".to_string()));
        assert_eq!(value.step_id, Some(step));
        // The undeclared keys should land in `extra`.
        assert_eq!(
            value.extra.as_ref().unwrap().get("another").unwrap(),
            &json!("value")
        );
        assert_eq!(
            value.extra.as_ref().unwrap().get("third").unwrap(),
            &json!("fourth")
        );
    }

    /// End-to-end non-streaming request against a mocked OpenAI-format
    /// custom provider, verified with an insta snapshot.
    #[tokio::test]
    async fn call_provider_nonstreaming() {
        // Mock an OpenAI-style chat completions endpoint.
        let mock_server = wiremock::MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(ChatResponse {
                created: 1,
                model: None,
                system_fingerprint: None,
                usage: Some(UsageResponse {
                    prompt_tokens: Some(1),
                    completion_tokens: Some(1),
                    total_tokens: Some(2),
                }),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage {
                        role: Some("assistant".to_string()),
                        content: Some("hello".to_string()),
                        tool_calls: Vec::new(),
                        ..Default::default()
                    },
                    finish_reason: crate::format::FinishReason::Stop,
                }],
            }))
            .mount(&mock_server)
            .await;

        let url = format!("{}/v1/chat/completions", mock_server.uri());

        // The "me/" prefix routes matching model names to this provider.
        let proxy = super::Proxy::builder()
            .with_custom_provider(CustomProviderConfig {
                name: "test".to_string(),
                url,
                format: ProviderRequestFormat::OpenAi(OpenAiRequestFormatOptions {
                    transforms: crate::format::ChatRequestTransformation {
                        supports_message_name: false,
                        system_in_messages: true,
                        strip_model_prefix: Some("me/".into()),
                    },
                }),
                label: None,
                api_key: None,
                api_key_source: None,
                headers: BTreeMap::default(),
                prefix: Some("me/".to_string()),
            })
            .build()
            .await
            .expect("Building proxy");

        let chan = proxy
            .send(
                crate::ProxyRequestOptions {
                    ..Default::default()
                },
                ChatRequest {
                    model: Some("me/a-test-model".to_string()),
                    messages: vec![ChatMessage {
                        role: Some("user".to_string()),
                        content: Some("hello".to_string()),
                        tool_calls: Vec::new(),
                        ..Default::default()
                    }],
                    ..Default::default()
                },
            )
            .await
            .expect("should have succeeded");

        let mut response = collect_response(chan, 1).await.unwrap();

        // The request ID is random; zero it for a stable snapshot.
        response.request_info.id = uuid::Uuid::nil();
        insta::assert_json_snapshot!(response);
    }

    /// End-to-end streaming request: two SSE chunks plus [DONE] from a
    /// mocked provider, collected and snapshot-verified.
    #[tokio::test]
    async fn call_provider_streaming() {
        let response1 = StreamingChatResponse {
            created: 1,
            model: Some("a_model".to_string()),
            system_fingerprint: Some("abbadada".to_string()),
            usage: Some(UsageResponse {
                prompt_tokens: Some(1),
                completion_tokens: Some(1),
                total_tokens: Some(2),
            }),
            choices: vec![ChatChoiceDelta {
                index: 0,
                delta: ChatMessage {
                    role: Some("assistant".to_string()),
                    content: Some("hello".to_string()),
                    tool_calls: Vec::new(),
                    ..Default::default()
                },
                finish_reason: None,
            }],
        };

        let response2 = StreamingChatResponse {
            created: 2,
            model: None,
            system_fingerprint: None,
            usage: Some(UsageResponse {
                prompt_tokens: Some(1),
                completion_tokens: Some(1),
                total_tokens: Some(2),
            }),
            choices: vec![ChatChoiceDelta {
                index: 0,
                delta: ChatMessage {
                    role: None,
                    content: Some(" and hello again".to_string()),
                    tool_calls: Vec::new(),
                    ..Default::default()
                },
                finish_reason: Some(crate::format::FinishReason::Stop),
            }],
        };

        // Build the SSE body the way an OpenAI-compatible server would.
        let response_data = format!(
            "data: {}\n\ndata: {}\n\ndata: [DONE]",
            serde_json::to_string(&response1).unwrap(),
            serde_json::to_string(&response2).unwrap(),
        );

        let mock_server = wiremock::MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(response_data, "text/event-stream"),
            )
            .mount(&mock_server)
            .await;

        let url = format!("{}/v1/chat/completions", mock_server.uri());

        let proxy = super::Proxy::builder()
            .with_custom_provider(CustomProviderConfig {
                name: "test".to_string(),
                url,
                format: ProviderRequestFormat::OpenAi(OpenAiRequestFormatOptions {
                    transforms: crate::format::ChatRequestTransformation {
                        supports_message_name: false,
                        system_in_messages: true,
                        strip_model_prefix: Some("me/".into()),
                    },
                }),
                label: None,
                api_key: None,
                api_key_source: None,
                headers: BTreeMap::default(),
                prefix: Some("me/".to_string()),
            })
            .build()
            .await
            .expect("Building proxy");

        let chan = proxy
            .send(
                crate::ProxyRequestOptions {
                    ..Default::default()
                },
                ChatRequest {
                    model: Some("me/a-test-model".to_string()),
                    messages: vec![ChatMessage {
                        role: Some("user".to_string()),
                        content: Some("hello".to_string()),
                        tool_calls: Vec::new(),
                        ..Default::default()
                    }],
                    stream: true,
                    ..Default::default()
                },
            )
            .await
            .expect("should have succeeded");

        let mut response = collect_response(chan, 1).await.unwrap();

        // The request ID is random; zero it for a stable snapshot.
        response.request_info.id = uuid::Uuid::nil();
        insta::assert_json_snapshot!(response);
    }
}