1use crate::clients::base::{LLMClient, LLMRequestOptions, LLMResponse, TextResponse, ToolCall};
8use crate::context::manager::ContextManager;
9use crate::error::StreamError;
10use crate::guardrails::{FinalResponseScorer, StepEnforcer, ToolCallScorer};
11use crate::proxy::{
12 extract_passthrough, extract_sampling, openai_to_messages, strip_respond_calls,
13 OpenAiMessageError,
14};
15use crate::schema_compression::{
16 compress_tool_schemas, patch_anthropic_tool_schemas, SchemaCompressionMode,
17};
18use crate::tool_output::{ToolOutputCompressionConfig, ToolOutputCompressionState};
19use crate::tool_policy::{
20 evaluate_tool_call_policy, ToolCallPolicyConfig, ToolCallPolicyRequestState,
21};
22use crate::tools::respond::RESPOND_TOOL_NAME;
23use anyllm_translate::anthropic::streaming::StreamEvent;
24use futures_core::Stream;
25use indexmap::IndexSet;
26use serde_json::Value;
27use std::fmt;
28use std::pin::Pin;
29use std::sync::Arc;
30use tokio::sync::Mutex;
31
32mod anthropic;
33mod classifier_log;
34mod compression;
35mod nudge;
36mod passthrough;
37mod prior_tool_results;
38mod request_contract;
39mod response_shape;
40mod scoring;
41mod telemetry;
42mod tool_specs;
43mod training_capture;
44
45pub use anthropic::{
46 handle_anthropic_messages, handle_anthropic_messages_with_scorer,
47 handle_anthropic_messages_with_scorers,
48 handle_anthropic_messages_with_scorers_and_tool_controls,
49 handle_anthropic_messages_with_scorers_and_tool_output_compression,
50 handle_anthropic_messages_with_scorers_tool_controls_and_headers,
51};
52use compression::{
53 compress_proxy_tool_results,
54 init_proxy_tool_output_compression_log_sink_from_env as init_compression_log_sink_from_env,
55 patch_anthropic_tool_results,
56 shutdown_proxy_tool_output_compression_log_sink as shutdown_compression_log_sink,
57};
58use nudge::{
59 emit_proxy_classifier_nudge_or_error, emit_proxy_final_response_tool_nudge_or_error,
60 emit_proxy_step_nudge_or_error, emit_proxy_tool_policy_nudge_or_error,
61 emit_proxy_user_classifier_nudge_or_error, synthetic_respond_tool_call,
62};
63pub use passthrough::run_passthrough;
64use prior_tool_results::record_completed_proxy_tool_results;
65#[cfg(test)]
66use request_contract::sanitize_guarded_anthropic_body;
67use request_contract::{
68 add_proxy_respond_tool_if_needed, extract_forge_bool_field, extract_forge_debug_context,
69 extract_proxy_step_contract, extract_schema_compression, extract_stream_include_usage,
70 extract_tool_call_policy_config, extract_tool_output_compression_config,
71 sanitize_guarded_request_options, strip_forge_extension_from_body,
72 validate_proxy_step_contract, FORGE_EXTENSION_FIELD, FORGE_REQUIRED_STEPS_FIELD,
73 FORGE_RETURN_RAW_ON_GUARDRAIL_FAILURE_FIELD,
74};
75#[cfg(test)]
76use response_shape::{collect_anthropic_events, collect_openai_events};
77use response_shape::{text_content_result, text_response_result, tool_calls_result};
78use scoring::{score_proxy_final_text, score_proxy_final_tool_calls, score_proxy_tool_calls};
79pub use tool_specs::parse_tool_specs;
80
81pub fn init_proxy_classifier_log_sink_from_env() {
83 classifier_log::init_proxy_classifier_log_sink_from_env();
84}
85
86pub fn init_proxy_training_capture_sink_from_env() {
88 training_capture::init_proxy_training_capture_sink_from_env();
89}
90
91pub fn init_proxy_tool_output_compression_log_sink_from_env() {
93 init_compression_log_sink_from_env();
94}
95
96pub fn shutdown_proxy_classifier_log_sink() {
98 classifier_log::shutdown_proxy_classifier_log_sink();
99}
100
101pub fn shutdown_proxy_training_capture_sink() {
103 training_capture::shutdown_proxy_training_capture_sink();
104}
105
106pub fn shutdown_proxy_tool_output_compression_log_sink() {
108 shutdown_compression_log_sink();
109}
110
111pub type OpenAiEventStream = Pin<Box<dyn Stream<Item = Result<Value, StreamError>> + Send>>;
113
114pub type AnthropicEventStream =
116 Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send>>;
117
118pub enum HandlerResult {
120 Response(Value),
122 StreamBody(OpenAiEventStream),
124 AnthropicResponse(Value),
126 AnthropicStreamBody(AnthropicEventStream),
128}
129
/// Step index handed to the inference runner on every iteration of the proxy's guarded loop.
const PROXY_STEP_INDEX: i64 = 0;
131
132impl fmt::Debug for HandlerResult {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 match self {
135 Self::Response(value) => f.debug_tuple("Response").field(value).finish(),
136 Self::StreamBody(_) => f.write_str("StreamBody(<openai event stream>)"),
137 Self::AnthropicResponse(value) => {
138 f.debug_tuple("AnthropicResponse").field(value).finish()
139 }
140 Self::AnthropicStreamBody(_) => {
141 f.write_str("AnthropicStreamBody(<anthropic event stream>)")
142 }
143 }
144 }
145}
146
/// Failure returned by the chat-completions handlers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandlerError {
    /// The incoming request was malformed or failed validation.
    BadRequest(String),
    /// The upstream call or guarded loop failed and no status code is known.
    Upstream(String),
    /// The upstream call failed and a status code could be recovered.
    UpstreamStatus {
        /// Human-readable description of the failure.
        message: String,
        /// Status code taken from the backend error (presumably an HTTP status — confirm
        /// against `BackendError`).
        status: i64,
    },
}
164
165impl HandlerError {
166 pub fn message(&self) -> &str {
168 match self {
169 Self::BadRequest(message)
170 | Self::Upstream(message)
171 | Self::UpstreamStatus { message, .. } => message,
172 }
173 }
174}
175
176impl fmt::Display for HandlerError {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 write!(f, "{}", self.message())
179 }
180}
181
182impl std::error::Error for HandlerError {}
183
184impl From<OpenAiMessageError> for HandlerError {
185 fn from(error: OpenAiMessageError) -> Self {
186 Self::BadRequest(error.to_string())
187 }
188}
189
190fn upstream_handler_error(err: crate::error::ForgeError) -> HandlerError {
194 use crate::error::ForgeError;
195 match &err {
196 ForgeError::Backend(backend) => HandlerError::UpstreamStatus {
198 message: backend.to_string(),
199 status: backend.status_code(),
200 },
201 ForgeError::Stream(stream) => {
206 match crate::error::BackendError::status_from_display(&stream.message) {
207 Some(status) => HandlerError::UpstreamStatus {
208 message: stream.to_string(),
209 status,
210 },
211 None => HandlerError::Upstream(stream.to_string()),
212 }
213 }
214 _ => HandlerError::Upstream(err.to_string()),
215 }
216}
217
218fn upstream_handler_error_from_message(message: String) -> HandlerError {
224 match crate::error::BackendError::status_from_display(&message) {
225 Some(status) => HandlerError::UpstreamStatus { message, status },
226 None => HandlerError::Upstream(message),
227 }
228}
229
230pub enum AnthropicHandlerResult {
232 Response(Value),
234 StreamBody(AnthropicEventStream),
236}
237
238impl fmt::Debug for AnthropicHandlerResult {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 match self {
241 Self::Response(value) => f.debug_tuple("Response").field(value).finish(),
242 Self::StreamBody(_) => f.write_str("StreamBody(<anthropic event stream>)"),
243 }
244 }
245}
246
/// Failure returned by the Anthropic-shaped handlers.
#[derive(Debug)]
pub enum AnthropicHandlerError {
    /// The incoming request was malformed or failed validation.
    BadRequest(String),
    /// The upstream call failed and no status code is known.
    Upstream(String),
    /// The upstream call failed and a status code is available.
    UpstreamStatus {
        /// Human-readable description of the failure.
        message: String,
        /// Status code associated with the failure (presumably an HTTP status — confirm
        /// against the code that constructs this variant).
        status: i64,
    },
    /// A failure inside the handler itself rather than in the request or the upstream.
    /// NOTE(review): inferred from the variant name; it is not constructed in this file.
    Internal(String),
}
266
267impl AnthropicHandlerError {
268 pub fn message(&self) -> &str {
270 match self {
271 Self::BadRequest(message)
272 | Self::Upstream(message)
273 | Self::UpstreamStatus { message, .. }
274 | Self::Internal(message) => message,
275 }
276 }
277}
278
279#[allow(clippy::too_many_arguments)]
288pub async fn handle_chat_completions<C: LLMClient + 'static>(
289 body: &Value,
290 client: &Arc<C>,
291 context_manager: &Arc<Mutex<ContextManager>>,
292 max_retries: i32,
293 rescue_enabled: bool,
294) -> Result<HandlerResult, HandlerError> {
295 handle_chat_completions_with_scorer(
296 body,
297 client,
298 context_manager,
299 max_retries,
300 rescue_enabled,
301 None,
302 )
303 .await
304}
305
306#[allow(clippy::too_many_arguments)]
308pub async fn handle_chat_completions_with_scorer<C: LLMClient + 'static>(
309 body: &Value,
310 client: &Arc<C>,
311 context_manager: &Arc<Mutex<ContextManager>>,
312 max_retries: i32,
313 rescue_enabled: bool,
314 scorer: Option<Arc<dyn ToolCallScorer>>,
315) -> Result<HandlerResult, HandlerError> {
316 handle_chat_completions_with_scorers(
317 body,
318 client,
319 context_manager,
320 max_retries,
321 rescue_enabled,
322 scorer,
323 None,
324 )
325 .await
326}
327
328#[allow(clippy::too_many_arguments)]
330pub async fn handle_chat_completions_with_scorers<C: LLMClient + 'static>(
331 body: &Value,
332 client: &Arc<C>,
333 context_manager: &Arc<Mutex<ContextManager>>,
334 max_retries: i32,
335 rescue_enabled: bool,
336 scorer: Option<Arc<dyn ToolCallScorer>>,
337 final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
338) -> Result<HandlerResult, HandlerError> {
339 handle_chat_completions_with_scorers_and_tool_controls(
340 body,
341 client,
342 context_manager,
343 max_retries,
344 rescue_enabled,
345 scorer,
346 final_response_scorer,
347 ToolOutputCompressionConfig::disabled(),
348 None,
349 ToolCallPolicyConfig::disabled(),
350 SchemaCompressionMode::Disabled,
351 )
352 .await
353}
354
355#[allow(clippy::too_many_arguments)]
357pub async fn handle_chat_completions_with_scorers_and_tool_output_compression<
358 C: LLMClient + 'static,
359>(
360 body: &Value,
361 client: &Arc<C>,
362 context_manager: &Arc<Mutex<ContextManager>>,
363 max_retries: i32,
364 rescue_enabled: bool,
365 scorer: Option<Arc<dyn ToolCallScorer>>,
366 final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
367 default_tool_output_compression: ToolOutputCompressionConfig,
368 tool_output_state: Option<Arc<ToolOutputCompressionState>>,
369) -> Result<HandlerResult, HandlerError> {
370 handle_chat_completions_with_scorers_and_tool_controls(
371 body,
372 client,
373 context_manager,
374 max_retries,
375 rescue_enabled,
376 scorer,
377 final_response_scorer,
378 default_tool_output_compression,
379 tool_output_state,
380 ToolCallPolicyConfig::disabled(),
381 SchemaCompressionMode::Disabled,
382 )
383 .await
384}
385
386#[allow(clippy::too_many_arguments)]
388pub async fn handle_chat_completions_with_scorers_and_tool_controls<C: LLMClient + 'static>(
389 body: &Value,
390 client: &Arc<C>,
391 context_manager: &Arc<Mutex<ContextManager>>,
392 max_retries: i32,
393 rescue_enabled: bool,
394 scorer: Option<Arc<dyn ToolCallScorer>>,
395 final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
396 default_tool_output_compression: ToolOutputCompressionConfig,
397 tool_output_state: Option<Arc<ToolOutputCompressionState>>,
398 default_tool_call_policy: ToolCallPolicyConfig,
399 default_schema_compression: SchemaCompressionMode,
400) -> Result<HandlerResult, HandlerError> {
401 handle_chat_completions_impl(
402 body,
403 client,
404 context_manager,
405 max_retries,
406 rescue_enabled,
407 None,
408 None,
409 scorer,
410 final_response_scorer,
411 default_tool_output_compression,
412 tool_output_state,
413 default_tool_call_policy,
414 default_schema_compression,
415 )
416 .await
417}
418
419#[allow(clippy::too_many_arguments)]
420pub(super) async fn handle_chat_completions_impl<C: LLMClient + 'static>(
421 body: &Value,
422 client: &Arc<C>,
423 context_manager: &Arc<Mutex<ContextManager>>,
424 max_retries: i32,
425 rescue_enabled: bool,
426 inbound_anthropic_body: Option<Value>,
427 anthropic_headers: Option<Vec<(String, String)>>,
428 scorer: Option<Arc<dyn ToolCallScorer>>,
429 final_response_scorer: Option<Arc<dyn FinalResponseScorer>>,
430 default_tool_output_compression: ToolOutputCompressionConfig,
431 tool_output_state: Option<Arc<ToolOutputCompressionState>>,
432 default_tool_call_policy: ToolCallPolicyConfig,
433 default_schema_compression: SchemaCompressionMode,
434) -> Result<HandlerResult, HandlerError> {
435 let messages = body
436 .get("messages")
437 .and_then(|m| m.as_array())
438 .ok_or_else(|| HandlerError::BadRequest("missing or invalid messages field".to_string()))?;
439
440 let model_name = body
441 .get("model")
442 .and_then(|m| m.as_str())
443 .unwrap_or("unknown");
444
445 let stream = body
446 .get("stream")
447 .and_then(|s| s.as_bool())
448 .unwrap_or(false);
449 let stream_include_usage = extract_stream_include_usage(body)?;
450
451 let tools_raw = match body.get("tools") {
452 Some(Value::Array(tools)) => tools.clone(),
453 Some(_) => {
454 return Err(HandlerError::BadRequest(
455 "tools must be an array".to_string(),
456 ));
457 }
458 None => Vec::new(),
459 };
460 let step_contract = extract_proxy_step_contract(body)?;
461 let return_raw_on_guardrail_failure =
462 extract_forge_bool_field(body, FORGE_RETURN_RAW_ON_GUARDRAIL_FAILURE_FIELD)?;
463 let tool_output_compression =
464 extract_tool_output_compression_config(body, &default_tool_output_compression)?;
465 let schema_compression = extract_schema_compression(body, default_schema_compression)?;
466 let forge_debug_context = extract_forge_debug_context(body)?;
467 let tool_call_policy = extract_tool_call_policy_config(body, &default_tool_call_policy)?;
468
469 let sampling = extract_sampling(body);
470 let mut passthrough = extract_passthrough(body);
471 if let Some(raw) = inbound_anthropic_body.as_ref() {
472 preserve_rebuilt_anthropic_fields(raw, &mut passthrough);
473 }
474 let mut request_options = LLMRequestOptions {
475 sampling,
476 passthrough,
477 inbound_anthropic_body: inbound_anthropic_body
478 .map(strip_forge_extension_from_body)
479 .map(Arc::new),
480 initial_openai_messages: None,
481 anthropic_headers,
482 preserve_provider_response: false,
483 };
484
485 let mut internal_msgs = openai_to_messages(messages)?;
487 let tool_output_updates = compress_proxy_tool_results(
488 &mut internal_msgs,
489 &tool_output_compression,
490 tool_output_state.as_deref(),
491 forge_debug_context.as_ref(),
492 );
493 if !tool_output_updates.is_empty() {
494 if let Some(body) = request_options.inbound_anthropic_body.take() {
495 let mut patched = body.as_ref().clone();
496 if patch_anthropic_tool_results(&mut patched, &tool_output_updates) {
497 request_options.inbound_anthropic_body = Some(Arc::new(patched));
498 } else {
499 tracing::warn!(
500 "failed to patch compressed tool outputs into Anthropic request body; falling back to rebuilt body which may discard custom metadata or cache_control flags"
501 );
502 }
503 }
504 }
505 if schema_compression != SchemaCompressionMode::Disabled {
507 if let Some(body) = request_options.inbound_anthropic_body.as_ref() {
508 let mut patched = body.as_ref().clone();
509 if patch_anthropic_tool_schemas(&mut patched, schema_compression) {
510 request_options.inbound_anthropic_body = Some(Arc::new(patched));
511 }
512 }
513 }
514 if request_options.inbound_anthropic_body.is_some() {
515 request_options.preserve_provider_response = true;
516 request_options.initial_openai_messages = Some(Arc::from(
517 crate::core::inference::fold_and_serialize(
518 &internal_msgs,
519 client.api_format().as_str(),
520 )
521 .into_boxed_slice(),
522 ));
523 }
524
525 if tools_raw.is_empty() {
527 if let Some(contract) = step_contract.as_ref() {
528 if !contract.required_steps.is_empty() {
529 return Err(HandlerError::BadRequest(format!(
530 "{FORGE_EXTENSION_FIELD}.{FORGE_REQUIRED_STEPS_FIELD} requires tools"
531 )));
532 }
533 }
534 let api_format = client.api_format().as_str();
535 let serialized = crate::core::inference::fold_and_serialize(&internal_msgs, api_format);
536 return run_passthrough(
537 client,
538 &serialized,
539 None,
540 request_options,
541 model_name,
542 stream,
543 stream_include_usage,
544 )
545 .await
546 .map_err(upstream_handler_error_from_message);
547 }
548
549 let mut tool_specs = parse_tool_specs(&tools_raw)?;
552 compress_tool_schemas(&mut tool_specs, schema_compression);
553 let respond_injected =
554 add_proxy_respond_tool_if_needed(&mut tool_specs, step_contract.as_ref());
555
556 let tool_names: IndexSet<String> = tool_specs.iter().map(|s| s.name.clone()).collect();
557 let step_contract = validate_proxy_step_contract(step_contract, &tool_names, respond_injected)?;
558 let request_options =
559 sanitize_guarded_request_options(request_options, step_contract.as_ref())?;
560 let validator = crate::guardrails::ResponseValidator::from_tool_specs(
561 tool_specs.clone(),
562 rescue_enabled,
563 None,
564 );
565 let mut error_tracker = crate::guardrails::ErrorTracker::new(max_retries, 2);
566 let mut tool_call_counter = 0;
567 let mut step_enforcer = step_contract.map(|contract| {
568 let mut enforcer = StepEnforcer::new(
569 contract.required_steps,
570 contract.terminal_tools.into_iter().collect(),
571 None,
572 3,
573 2,
574 );
575 record_completed_proxy_tool_results(messages, &internal_msgs, &mut enforcer);
576 enforcer
577 });
578
579 let mut accepted_usage = None;
580 let mut accepted_usage_details = None;
581 let mut accepted_provider_response = None;
582 let mut accepted_provider_events = None;
583 let mut tool_call_policy_state = ToolCallPolicyRequestState::new();
584 let response = loop {
585 let step_hint = step_enforcer
586 .as_ref()
587 .map(StepEnforcer::summary_hint)
588 .unwrap_or_default();
589 let inference_result = crate::core::inference::run_inference_with_options_shared_context(
590 &mut internal_msgs,
591 client.as_ref(),
592 context_manager.as_ref(),
593 &validator,
594 &mut error_tracker,
595 &tool_specs,
596 &mut tool_call_counter,
597 PROXY_STEP_INDEX,
598 &step_hint,
599 Some(max_retries + 1),
600 stream,
601 None,
602 request_options.clone(),
603 )
604 .await;
605
606 let result = match inference_result {
607 Ok(Some(result)) => result,
608 Ok(None) => break LLMResponse::Text(TextResponse::new("")),
609 Err(crate::error::ForgeError::ToolCall(err)) => {
610 telemetry::capture_guardrail_exhausted(
611 "deterministic_tool_validation_exhausted",
612 &[],
613 &[],
614 Some(error_tracker.consecutive_retries()),
615 Some(error_tracker.max_retries()),
616 Some(stream),
617 );
618 if !return_raw_on_guardrail_failure {
619 return Err(HandlerError::Upstream(format!(
620 "model failed guarded tool-call validation after retries: {}",
621 err
622 )));
623 }
624 let raw = err.raw_response.unwrap_or_default();
625 let usage = client.last_usage();
626 let usage_details = client.last_usage_details();
627 return Ok(text_content_result(
628 &raw,
629 model_name,
630 stream,
631 stream_include_usage,
632 usage.as_ref(),
633 usage_details.as_ref(),
634 ));
635 }
636 Err(err) => return Err(upstream_handler_error(err)),
637 };
638
639 tool_call_counter = result.tool_call_counter;
640
641 let result_usage = result.usage;
642 let result_usage_details = result.usage_details;
643 let result_provider_response = result.provider_response;
644 let result_provider_events = result.provider_events;
645 let response = result.response;
646 let Some(enforcer) = step_enforcer.as_mut() else {
647 match response {
648 LLMResponse::ToolCalls(tool_calls) => {
649 if let Some(nudge) = evaluate_tool_call_policy(
650 &tool_calls,
651 &tool_specs,
652 &tool_call_policy,
653 &mut tool_call_policy_state,
654 ) {
655 emit_proxy_tool_policy_nudge_or_error(
656 &mut error_tracker,
657 tool_calls,
658 &mut internal_msgs,
659 &mut tool_call_counter,
660 &nudge.content,
661 )
662 .map_err(HandlerError::Upstream)?;
663 continue;
664 }
665 if let Some(nudge) = score_proxy_tool_calls(
666 scorer.clone(),
667 &internal_msgs,
668 &tool_calls,
669 None,
670 &tool_specs,
671 )
672 .await
673 {
674 emit_proxy_classifier_nudge_or_error(
675 &mut error_tracker,
676 tool_calls,
677 &mut internal_msgs,
678 &mut tool_call_counter,
679 &nudge,
680 )
681 .map_err(HandlerError::Upstream)?;
682 continue;
683 }
684 if let Some(nudge) = score_proxy_final_tool_calls(
685 final_response_scorer.clone(),
686 &internal_msgs,
687 &tool_calls,
688 None,
689 &tool_specs,
690 )
691 .await
692 {
693 emit_proxy_final_response_tool_nudge_or_error(
694 &mut error_tracker,
695 tool_calls,
696 &mut internal_msgs,
697 &mut tool_call_counter,
698 &nudge,
699 )
700 .map_err(HandlerError::Upstream)?;
701 continue;
702 }
703 accepted_usage = result_usage;
704 accepted_usage_details = result_usage_details;
705 accepted_provider_response = result_provider_response;
706 accepted_provider_events = result_provider_events;
707 break LLMResponse::ToolCalls(tool_calls);
708 }
709 LLMResponse::Text(text) => {
710 if let Some(nudge) = score_proxy_final_text(
711 final_response_scorer.clone(),
712 &internal_msgs,
713 &text.content,
714 None,
715 &tool_specs,
716 )
717 .await
718 {
719 emit_proxy_user_classifier_nudge_or_error(
720 &mut error_tracker,
721 &mut internal_msgs,
722 &nudge,
723 )
724 .map_err(HandlerError::Upstream)?;
725 continue;
726 }
727 accepted_usage = result_usage;
728 accepted_usage_details = result_usage_details;
729 accepted_provider_response = result_provider_response;
730 accepted_provider_events = result_provider_events;
731 break LLMResponse::Text(text);
732 }
733 }
734 };
735
736 match response {
737 LLMResponse::ToolCalls(tool_calls) => {
738 if !enforcer.is_satisfied() {
739 let step_check = enforcer.check(&tool_calls);
740 if step_check.needs_nudge {
741 emit_proxy_step_nudge_or_error(
742 enforcer,
743 step_check,
744 tool_calls,
745 &mut internal_msgs,
746 &mut tool_call_counter,
747 )
748 .map_err(HandlerError::Upstream)?;
749 continue;
750 }
751 }
752
753 if let Some(nudge) = evaluate_tool_call_policy(
754 &tool_calls,
755 &tool_specs,
756 &tool_call_policy,
757 &mut tool_call_policy_state,
758 ) {
759 emit_proxy_tool_policy_nudge_or_error(
760 &mut error_tracker,
761 tool_calls,
762 &mut internal_msgs,
763 &mut tool_call_counter,
764 &nudge.content,
765 )
766 .map_err(HandlerError::Upstream)?;
767 continue;
768 }
769 if let Some(nudge) = score_proxy_tool_calls(
770 scorer.clone(),
771 &internal_msgs,
772 &tool_calls,
773 Some(enforcer),
774 &tool_specs,
775 )
776 .await
777 {
778 emit_proxy_classifier_nudge_or_error(
779 &mut error_tracker,
780 tool_calls,
781 &mut internal_msgs,
782 &mut tool_call_counter,
783 &nudge,
784 )
785 .map_err(HandlerError::Upstream)?;
786 continue;
787 }
788 if let Some(nudge) = score_proxy_final_tool_calls(
789 final_response_scorer.clone(),
790 &internal_msgs,
791 &tool_calls,
792 Some(enforcer),
793 &tool_specs,
794 )
795 .await
796 {
797 emit_proxy_final_response_tool_nudge_or_error(
798 &mut error_tracker,
799 tool_calls,
800 &mut internal_msgs,
801 &mut tool_call_counter,
802 &nudge,
803 )
804 .map_err(HandlerError::Upstream)?;
805 continue;
806 }
807 accepted_usage = result_usage;
808 accepted_usage_details = result_usage_details;
809 accepted_provider_response = result_provider_response;
810 accepted_provider_events = result_provider_events;
811 break LLMResponse::ToolCalls(tool_calls);
812 }
813 LLMResponse::Text(text) => {
814 if !enforcer.is_satisfied() {
815 let tool_calls = vec![synthetic_respond_tool_call(&text)];
816 let step_check = enforcer.check(&tool_calls);
817 if step_check.needs_nudge {
818 emit_proxy_step_nudge_or_error(
819 enforcer,
820 step_check,
821 tool_calls,
822 &mut internal_msgs,
823 &mut tool_call_counter,
824 )
825 .map_err(HandlerError::Upstream)?;
826 continue;
827 }
828 }
829
830 if let Some(nudge) = score_proxy_final_text(
831 final_response_scorer.clone(),
832 &internal_msgs,
833 &text.content,
834 Some(enforcer),
835 &tool_specs,
836 )
837 .await
838 {
839 emit_proxy_user_classifier_nudge_or_error(
840 &mut error_tracker,
841 &mut internal_msgs,
842 &nudge,
843 )
844 .map_err(HandlerError::Upstream)?;
845 continue;
846 }
847 accepted_usage = result_usage;
848 accepted_usage_details = result_usage_details;
849 accepted_provider_response = result_provider_response;
850 accepted_provider_events = result_provider_events;
851 break LLMResponse::Text(text);
852 }
853 }
854 };
855
856 let usage = accepted_usage;
857 let usage_details = accepted_usage_details;
858 let provider_response = accepted_provider_response;
859 let provider_events = accepted_provider_events;
860
861 let handler_result = match response {
862 LLMResponse::Text(ref text) => text_response_result(
863 text,
864 model_name,
865 stream,
866 stream_include_usage,
867 usage.as_ref(),
868 usage_details.as_ref(),
869 ),
870 LLMResponse::ToolCalls(ref calls) => {
871 let (real_calls, respond_text) = strip_respond_calls(calls);
872 training_capture::emit_proxy_training_tool_call_candidates(
873 &internal_msgs,
874 &real_calls,
875 step_enforcer.as_ref(),
876 &tool_specs,
877 );
878
879 if real_calls.is_empty() {
880 let text = respond_text.unwrap_or_default();
882 text_content_result(
883 &text,
884 model_name,
885 stream,
886 stream_include_usage,
887 usage.as_ref(),
888 usage_details.as_ref(),
889 )
890 } else if respond_text.is_none() {
891 if stream {
892 if let Some(events) = provider_events {
893 anthropic_stream_result(events)
894 } else {
895 tool_calls_result(
896 &real_calls,
897 model_name,
898 stream,
899 stream_include_usage,
900 usage.as_ref(),
901 usage_details.as_ref(),
902 )
903 }
904 } else if let Some(value) = provider_response {
905 HandlerResult::AnthropicResponse(value)
906 } else {
907 tool_calls_result(
908 &real_calls,
909 model_name,
910 stream,
911 stream_include_usage,
912 usage.as_ref(),
913 usage_details.as_ref(),
914 )
915 }
916 } else {
917 tool_calls_result(
919 &real_calls,
920 model_name,
921 stream,
922 stream_include_usage,
923 usage.as_ref(),
924 usage_details.as_ref(),
925 )
926 }
927 }
928 };
929
930 Ok(handler_result)
931}
932
933fn preserve_rebuilt_anthropic_fields(
934 raw: &Value,
935 passthrough: &mut Option<serde_json::Map<String, Value>>,
936) {
937 let mut insert = |key: &str| {
938 let Some(value) = raw.get(key) else {
939 return;
940 };
941 passthrough
942 .get_or_insert_with(serde_json::Map::new)
943 .entry(key.to_string())
944 .or_insert_with(|| value.clone());
945 };
946 insert("thinking");
947 insert("output_config");
948}
949
950pub fn filter_respond(calls: &[ToolCall]) -> Vec<ToolCall> {
952 calls
953 .iter()
954 .filter(|c| c.tool != RESPOND_TOOL_NAME)
955 .cloned()
956 .collect()
957}
958
959pub fn process_response(response: &LLMResponse, model_name: &str, stream: bool) -> HandlerResult {
961 match response {
962 LLMResponse::ToolCalls(calls) => {
963 tool_calls_result(calls, model_name, stream, false, None, None)
964 }
965 LLMResponse::Text(text) => {
966 text_response_result(text, model_name, stream, false, None, None)
967 }
968 }
969}
970
971fn anthropic_stream_result(events: Vec<Value>) -> HandlerResult {
972 HandlerResult::AnthropicStreamBody(Box::pin(async_stream::stream! {
973 for event in events {
974 match serde_json::from_value::<StreamEvent>(event) {
975 Ok(event) => yield Ok(event),
976 Err(err) => {
977 yield Err(StreamError::new(err.to_string()));
978 return;
979 }
980 }
981 }
982 }))
983}
984
985#[cfg(test)]
986mod tests;