1use serde::Deserialize;
20
21use crate::api::llm::LlmRequest;
22use crate::error::{FlowError, Result};
23use crate::json::Json;
24
25use super::request::{
26 AnnotatedLlmRequest, FunctionDefinition, GenerationParams, Message, MessageContent, ToolChoice,
27 ToolChoiceFunction, ToolChoiceFunctionName, ToolDefinition,
28};
29use super::response::{
30 AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage,
31};
32use super::traits::{LlmCodec, LlmResponseCodec};
33
/// Codec translating between raw Anthropic Messages API payloads and the
/// provider-neutral annotated request/response types.
///
/// A stateless unit struct; it implements both [`LlmCodec`] (requests) and
/// [`LlmResponseCodec`] (non-streaming responses).
pub struct AnthropicMessagesCodec;
40
/// Loosely-typed view of an Anthropic Messages API response body.
///
/// Every field is optional, so a body missing any of these keys still
/// deserializes. Content blocks are kept as raw JSON because
/// `decode_response` interprets them by their `type`. Top-level keys not
/// listed here are collected into `extra`.
#[derive(Deserialize)]
struct RawAnthropicResponse {
    id: Option<String>,
    // The wire key is `type`, which is a Rust keyword.
    #[serde(rename = "type")]
    object_type: Option<String>,
    role: Option<String>,
    model: Option<String>,
    content: Option<Vec<Json>>,
    stop_reason: Option<String>,
    stop_sequence: Option<String>,
    service_tier: Option<String>,
    container: Option<Json>,
    usage: Option<RawAnthropicUsage>,
    // Catch-all for unmodeled keys; forwarded as `extra` on the annotated response.
    #[serde(flatten)]
    extra: serde_json::Map<String, Json>,
}
61
/// Token counts from the `usage` object of an Anthropic Messages response.
///
/// Each count is optional, so a missing key deserializes as `None`.
#[derive(Deserialize)]
struct RawAnthropicUsage {
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
    // Surfaced as `Usage::cache_read_tokens` by `decode_response`.
    cache_read_input_tokens: Option<u64>,
    // Surfaced as `Usage::cache_write_tokens` by `decode_response`.
    cache_creation_input_tokens: Option<u64>,
}
69
70fn map_anthropic_stop_reason(reason: &str) -> FinishReason {
76 match reason {
77 "end_turn" => FinishReason::Complete,
78 "max_tokens" => FinishReason::Length,
79 "tool_use" => FinishReason::ToolUse,
80 other => FinishReason::Unknown(other.to_string()),
81 }
82}
83
84fn json_f64(v: f64) -> Json {
86 serde_json::Number::from_f64(v)
87 .map(Json::Number)
88 .unwrap_or(Json::Null)
89}
90
/// Top-level request keys that `decode` maps onto typed fields of
/// [`AnnotatedLlmRequest`]. Every other top-level key is copied verbatim into
/// `AnnotatedLlmRequest::extra`.
const MODELED_REQUEST_KEYS: &[&str] = &[
    "system",
    "messages",
    "model",
    "max_tokens",
    "temperature",
    "top_p",
    "stop_sequences",
    "tools",
    // Also the source of `parallel_tool_calls` (via `disable_parallel_tool_use`).
    "tool_choice",
    "metadata",
    "service_tier",
];
105
106fn decode_anthropic_tool_choice(val: &Json) -> Option<ToolChoice> {
114 let obj = val.as_object()?;
115 let tc_type = obj.get("type")?.as_str()?;
116 match tc_type {
117 "auto" => Some(ToolChoice::Auto),
118 "any" => Some(ToolChoice::Required),
119 "none" => Some(ToolChoice::None),
120 "tool" => {
121 let name = obj.get("name")?.as_str()?.to_string();
122 Some(ToolChoice::Specific(ToolChoiceFunction {
123 choice_type: "function".into(),
124 function: ToolChoiceFunctionName { name },
125 }))
126 }
127 _ => None,
128 }
129}
130
131fn decode_parallel_tool_calls(val: &Json) -> Option<bool> {
134 let obj = val.as_object()?;
135 obj.get("disable_parallel_tool_use")
136 .and_then(|v| v.as_bool())
137 .map(|disabled| !disabled)
138}
139
140fn encode_anthropic_tool_choice(tc: &ToolChoice) -> Json {
142 match tc {
143 ToolChoice::Auto => serde_json::json!({"type": "auto"}),
144 ToolChoice::Required => serde_json::json!({"type": "any"}),
145 ToolChoice::None => serde_json::json!({"type": "none"}),
146 ToolChoice::Specific(func) => {
147 serde_json::json!({"type": "tool", "name": func.function.name})
148 }
149 }
150}
151
152fn encode_tool_choice_with_parallel_hint(
153 tc: &ToolChoice,
154 parallel_tool_calls: Option<bool>,
155) -> Json {
156 let mut value = encode_anthropic_tool_choice(tc);
157 if let (Some(parallel), Some(obj)) = (parallel_tool_calls, value.as_object_mut()) {
158 obj.insert("disable_parallel_tool_use".into(), Json::Bool(!parallel));
159 }
160 value
161}
162
163fn extract_system_message(system_val: &Json) -> Option<Message> {
167 if let Some(s) = system_val.as_str() {
168 Some(Message::System {
169 content: MessageContent::Text(s.to_string()),
170 name: None,
171 })
172 } else if let Some(arr) = system_val.as_array() {
173 let texts: Vec<&str> = arr
175 .iter()
176 .filter_map(|block| {
177 let block_type = block.get("type")?.as_str()?;
178 if block_type == "text" {
179 block.get("text")?.as_str()
180 } else {
181 None
182 }
183 })
184 .collect();
185 if texts.is_empty() {
186 None
187 } else {
188 Some(Message::System {
189 content: MessageContent::Text(texts.join("\n")),
190 name: None,
191 })
192 }
193 } else {
194 None
195 }
196}
197
198fn extract_system_text(msg: &Message) -> Option<String> {
200 match msg {
201 Message::System {
202 content: MessageContent::Text(s),
203 ..
204 } => Some(s.clone()),
205 Message::System {
206 content: MessageContent::Parts(parts),
207 ..
208 } => {
209 let texts: Vec<&str> = parts
210 .iter()
211 .filter_map(|p| match p {
212 super::request::ContentPart::Text { text } => Some(text.as_str()),
213 super::request::ContentPart::ImageUrl { .. } => None,
214 })
215 .collect();
216 if texts.is_empty() {
217 None
218 } else {
219 Some(texts.join("\n"))
220 }
221 }
222 _ => None,
223 }
224}
225
226fn split_system_and_messages(messages: &[Message]) -> (Option<String>, Vec<&Message>) {
227 let mut system_text = None;
228 let mut non_system_messages = Vec::new();
229
230 for msg in messages {
231 if let Some(text) = extract_system_text(msg) {
232 system_text = Some(text);
233 } else {
234 non_system_messages.push(msg);
235 }
236 }
237
238 (system_text, non_system_messages)
239}
240
241fn insert_serialized<T: serde::Serialize>(
242 obj: &mut serde_json::Map<String, Json>,
243 key: &str,
244 value: &T,
245 context: &str,
246) -> Result<()> {
247 let json = serde_json::to_value(value)
248 .map_err(|e| FlowError::Internal(format!("Anthropic Messages {context} encode: {e}")))?;
249 obj.insert(key.into(), json);
250 Ok(())
251}
252
253fn overlay_generation_params(obj: &mut serde_json::Map<String, Json>, params: &GenerationParams) {
254 if let Some(temp) = params.temperature {
255 obj.insert("temperature".into(), json_f64(temp));
256 }
257 if let Some(top_p) = params.top_p {
258 obj.insert("top_p".into(), json_f64(top_p));
259 }
260 if let Some(max_tokens) = params.max_tokens {
261 obj.insert("max_tokens".into(), Json::from(max_tokens));
262 }
263}
264
265fn encode_anthropic_tools(tools: &[ToolDefinition]) -> Vec<Json> {
266 tools
267 .iter()
268 .map(|td| {
269 let mut tool = serde_json::Map::new();
270 tool.insert("name".into(), Json::String(td.function.name.clone()));
271 if let Some(ref desc) = td.function.description {
272 tool.insert("description".into(), Json::String(desc.clone()));
273 }
274 if let Some(ref params) = td.function.parameters {
275 tool.insert("input_schema".into(), params.clone());
276 }
277 Json::Object(tool)
278 })
279 .collect()
280}
281
282impl LlmResponseCodec for AnthropicMessagesCodec {
287 fn decode_response(&self, response: &Json) -> Result<AnnotatedLlmResponse> {
288 let raw: RawAnthropicResponse = serde_json::from_value(response.clone())
289 .map_err(|e| FlowError::Internal(format!("Anthropic Messages response decode: {e}")))?;
290
291 let content_blocks = raw.content.as_ref();
293
294 let text_parts: Vec<&str> = content_blocks
296 .map(|blocks| {
297 blocks
298 .iter()
299 .filter_map(|block| {
300 let block_type = block.get("type")?.as_str()?;
301 if block_type == "text" {
302 block.get("text")?.as_str()
303 } else {
304 None
305 }
306 })
307 .collect()
308 })
309 .unwrap_or_default();
310
311 let message = if text_parts.is_empty() {
312 None
313 } else {
314 Some(MessageContent::Text(text_parts.join("\n")))
315 };
316
317 let tool_calls: Vec<ResponseToolCall> = content_blocks
319 .map(|blocks| {
320 blocks
321 .iter()
322 .filter_map(|block| {
323 let block_type = block.get("type")?.as_str()?;
324 if block_type == "tool_use" {
325 let id = block.get("id")?.as_str()?.to_string();
326 let name = block.get("name")?.as_str()?.to_string();
327 let arguments = block.get("input")?.clone();
329 Some(ResponseToolCall {
330 id,
331 name,
332 arguments,
333 })
334 } else {
335 None
336 }
337 })
338 .collect()
339 })
340 .unwrap_or_default();
341
342 let tool_calls = if tool_calls.is_empty() {
343 None
344 } else {
345 Some(tool_calls)
346 };
347
348 let finish_reason = raw.stop_reason.as_deref().map(map_anthropic_stop_reason);
350
351 let usage = raw.usage.map(|u| {
353 let prompt = u.input_tokens;
354 let completion = u.output_tokens;
355 Usage {
356 prompt_tokens: prompt,
357 completion_tokens: completion,
358 total_tokens: match (prompt, completion) {
360 (Some(p), Some(c)) => Some(p + c),
361 _ => None,
362 },
363 cache_read_tokens: u.cache_read_input_tokens,
364 cache_write_tokens: u.cache_creation_input_tokens,
365 }
366 });
367
368 let api_specific_content_blocks = raw.content.clone();
370 let api_specific = Some(ApiSpecificResponse::AnthropicMessages {
371 object_type: raw.object_type,
372 role: raw.role,
373 stop_reason: raw.stop_reason,
374 stop_sequence: raw.stop_sequence,
375 service_tier: raw.service_tier,
376 container: raw.container,
377 content_blocks: api_specific_content_blocks,
378 });
379
380 Ok(AnnotatedLlmResponse {
381 id: raw.id,
382 model: raw.model,
383 message,
384 tool_calls,
385 finish_reason,
386 usage,
387 api_specific,
388 extra: raw.extra,
389 })
390 }
391}
392
393impl LlmCodec for AnthropicMessagesCodec {
398 fn decode(&self, request: &LlmRequest) -> Result<AnnotatedLlmRequest> {
399 let obj = request
400 .content
401 .as_object()
402 .ok_or_else(|| FlowError::Internal("request content is not an object".into()))?;
403
404 let system_msg = obj.get("system").and_then(extract_system_message);
406
407 let mut messages: Vec<Message> = obj
409 .get("messages")
410 .map(|v| serde_json::from_value(v.clone()).unwrap_or_default())
411 .unwrap_or_default();
412
413 if let Some(sys) = system_msg {
415 messages.insert(0, sys);
416 }
417
418 let model = obj.get("model").and_then(|v| v.as_str()).map(String::from);
420
421 let temperature = obj.get("temperature").and_then(|v| v.as_f64());
423 let top_p = obj.get("top_p").and_then(|v| v.as_f64());
424 let max_tokens = obj.get("max_tokens").and_then(|v| v.as_u64());
425 let stop = obj
427 .get("stop_sequences")
428 .and_then(|v| serde_json::from_value::<Vec<String>>(v.clone()).ok());
429
430 let params =
431 if temperature.is_some() || max_tokens.is_some() || top_p.is_some() || stop.is_some() {
432 Some(GenerationParams {
433 temperature,
434 max_tokens,
435 top_p,
436 stop,
437 })
438 } else {
439 None
440 };
441
442 let tools: Option<Vec<ToolDefinition>> = obj.get("tools").and_then(|v| {
445 let arr = v.as_array()?;
446 let defs: Vec<ToolDefinition> = arr
447 .iter()
448 .filter_map(|tool| {
449 let name = tool.get("name")?.as_str()?.to_string();
450 let description = tool
451 .get("description")
452 .and_then(|d| d.as_str())
453 .map(String::from);
454 let parameters = tool.get("input_schema").cloned();
455 Some(ToolDefinition {
456 tool_type: "function".into(),
457 function: FunctionDefinition {
458 name,
459 description,
460 parameters,
461 },
462 })
463 })
464 .collect();
465 if defs.is_empty() { None } else { Some(defs) }
466 });
467
468 let tool_choice = obj
470 .get("tool_choice")
471 .and_then(decode_anthropic_tool_choice);
472 let parallel_tool_calls = obj.get("tool_choice").and_then(decode_parallel_tool_calls);
473
474 let extra: serde_json::Map<String, Json> = obj
476 .iter()
477 .filter(|(k, _)| !MODELED_REQUEST_KEYS.contains(&k.as_str()))
478 .map(|(k, v)| (k.clone(), v.clone()))
479 .collect();
480
481 Ok(AnnotatedLlmRequest {
482 messages,
483 model,
484 params,
485 tools,
486 tool_choice,
487 store: None,
488 previous_response_id: None,
489 truncation: None,
490 reasoning: None,
491 include: None,
492 user: None,
493 metadata: obj.get("metadata").cloned(),
494 service_tier: obj
495 .get("service_tier")
496 .and_then(|v| v.as_str())
497 .map(String::from),
498 parallel_tool_calls,
499 max_output_tokens: None,
500 max_tool_calls: None,
501 top_logprobs: None,
502 stream: None,
503 extra,
504 })
505 }
506
507 fn encode(&self, annotated: &AnnotatedLlmRequest, original: &LlmRequest) -> Result<LlmRequest> {
508 let mut content = original.content.clone();
509 let obj = content
510 .as_object_mut()
511 .ok_or_else(|| FlowError::Internal("original content is not an object".into()))?;
512
513 let (system_text, non_system_messages) = split_system_and_messages(&annotated.messages);
514
515 if let Some(text) = system_text {
516 obj.insert("system".into(), Json::String(text));
517 }
518
519 insert_serialized(obj, "messages", &non_system_messages, "messages")?;
521
522 if let Some(ref model) = annotated.model {
524 obj.insert("model".into(), Json::String(model.clone()));
525 }
526
527 if let Some(ref params) = annotated.params {
529 overlay_generation_params(obj, params);
530 if let Some(ref stop) = params.stop {
532 insert_serialized(obj, "stop_sequences", stop, "stop_sequences")?;
533 }
534 }
535
536 if let Some(ref tools) = annotated.tools {
539 let anthropic_tools = encode_anthropic_tools(tools);
540 insert_serialized(obj, "tools", &anthropic_tools, "tools")?;
541 }
542
543 if let Some(ref tool_choice) = annotated.tool_choice {
545 obj.insert(
546 "tool_choice".into(),
547 encode_tool_choice_with_parallel_hint(tool_choice, annotated.parallel_tool_calls),
548 );
549 }
550
551 if let Some(ref metadata) = annotated.metadata {
552 obj.insert("metadata".into(), metadata.clone());
553 }
554 if let Some(ref service_tier) = annotated.service_tier {
555 obj.insert("service_tier".into(), Json::String(service_tier.clone()));
556 }
557
558 for (k, v) in &annotated.extra {
560 obj.insert(k.clone(), v.clone());
561 }
562
563 Ok(LlmRequest {
564 headers: original.headers.clone(),
565 content,
566 })
567 }
568}
569
/// Streaming counterpart of [`AnthropicMessagesCodec`].
///
/// The collector callback feeds Anthropic Messages stream events into shared
/// state. The finalizer callback reassembles them into a single response
/// object.
pub struct AnthropicMessagesStreamingCodec {
    // Shared by the collector and finalizer closures, which each hold a clone
    // of this `Arc`.
    state: std::sync::Arc<std::sync::Mutex<AnthropicMessagesStreamingState>>,
}
591
592impl AnthropicMessagesStreamingCodec {
593 pub fn new() -> Self {
595 Self {
596 state: std::sync::Arc::new(std::sync::Mutex::new(
597 AnthropicMessagesStreamingState::default(),
598 )),
599 }
600 }
601}
602
603impl Default for AnthropicMessagesStreamingCodec {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
609impl super::streaming::StreamingCodec for AnthropicMessagesStreamingCodec {
610 fn collector(&self) -> crate::api::runtime::LlmCollectorFn {
611 let state = std::sync::Arc::clone(&self.state);
612 Box::new(move |event: Json| -> Result<()> {
613 let mut guard = state
614 .lock()
615 .unwrap_or_else(|poisoned| poisoned.into_inner());
616 guard.observe(&event);
617 Ok(())
618 })
619 }
620
621 fn finalizer(&self) -> crate::api::runtime::LlmFinalizerFn {
622 let state = std::sync::Arc::clone(&self.state);
623 Box::new(move || -> Json {
624 let mut guard = state
625 .lock()
626 .unwrap_or_else(|poisoned| poisoned.into_inner());
627 std::mem::take(&mut *guard).finalize()
630 })
631 }
632}
633
/// Accumulator for one Anthropic Messages stream.
///
/// Top-level message fields are taken from `message_start` and
/// `message_delta` events. Content blocks are rebuilt from
/// `content_block_start` and `content_block_delta` events.
#[derive(Debug, Default)]
struct AnthropicMessagesStreamingState {
    id: Option<String>,
    // The message object's `type` (`type` is a Rust keyword).
    type_: Option<String>,
    role: Option<String>,
    model: Option<String>,
    // Raw usage object; written to the finalized response as-is.
    usage: Option<Json>,
    stop_reason: Option<String>,
    // Raw JSON so that an explicit `null` is reproduced in the output.
    stop_sequence: Option<Json>,
    // Indexed by each event's content-block `index`. `None` fills gaps for
    // indices that never received a `content_block_start`; those are omitted
    // when finalizing.
    blocks: Vec<Option<StreamingBlock>>,
}
650
/// One content block being reassembled from stream events.
///
/// `skeleton` is the `content_block` object from `content_block_start`.
/// Accumulated deltas are written over it when the block is finalized. Each
/// `has_*` flag records whether at least one delta of that kind arrived, so
/// skeleton fields without deltas are left untouched.
#[derive(Debug, Default, Clone)]
struct StreamingBlock {
    skeleton: serde_json::Map<String, Json>,
    // Concatenation of all `text_delta` fragments.
    text: String,
    has_text: bool,
    // Concatenation of all `input_json_delta` fragments.
    partial_json: String,
    has_partial_json: bool,
    // One entry per `citations_delta` event, in arrival order.
    citations: Vec<Json>,
    has_citations: bool,
}
664
665impl AnthropicMessagesStreamingState {
666 fn observe(&mut self, event: &Json) {
667 let event_type = event.get("type").and_then(Json::as_str).unwrap_or("");
668 match event_type {
669 "message_start" => self.observe_message_start(event),
670 "content_block_start" => self.observe_content_block_start(event),
671 "content_block_delta" => self.observe_content_block_delta(event),
672 "message_delta" => self.observe_message_delta(event),
673 _ => {}
677 }
678 }
679
680 fn observe_message_start(&mut self, event: &Json) {
681 let Some(message) = event.get("message") else {
682 return;
683 };
684 if let Some(id) = message.get("id").and_then(Json::as_str) {
685 self.id = Some(id.to_string());
686 }
687 if let Some(model) = message.get("model").and_then(Json::as_str) {
688 self.model = Some(model.to_string());
689 }
690 if let Some(role) = message.get("role").and_then(Json::as_str) {
691 self.role = Some(role.to_string());
692 }
693 if let Some(t) = message.get("type").and_then(Json::as_str) {
694 self.type_ = Some(t.to_string());
695 }
696 if let Some(usage) = message.get("usage") {
697 self.usage = Some(usage.clone());
698 }
699 }
700
701 fn observe_content_block_start(&mut self, event: &Json) {
702 let Some(index) = event.get("index").and_then(Json::as_u64) else {
703 return;
704 };
705 let Some(content_block) = event.get("content_block") else {
706 return;
707 };
708 let skeleton = match content_block {
709 Json::Object(map) => map.clone(),
710 _ => return,
711 };
712 let index = index as usize;
713 while self.blocks.len() <= index {
714 self.blocks.push(None);
715 }
716 self.blocks[index] = Some(StreamingBlock {
717 skeleton,
718 ..StreamingBlock::default()
719 });
720 }
721
722 fn observe_content_block_delta(&mut self, event: &Json) {
723 let Some(index) = event.get("index").and_then(Json::as_u64) else {
724 return;
725 };
726 let index = index as usize;
727 let Some(delta) = event.get("delta") else {
728 return;
729 };
730 let delta_type = delta.get("type").and_then(Json::as_str).unwrap_or("");
731 let Some(slot) = self.blocks.get_mut(index) else {
732 return;
733 };
734 let Some(block) = slot.as_mut() else { return };
735 match delta_type {
736 "text_delta" => {
737 if let Some(text) = delta.get("text").and_then(Json::as_str) {
738 block.text.push_str(text);
739 block.has_text = true;
740 }
741 }
742 "input_json_delta" => {
743 if let Some(partial) = delta.get("partial_json").and_then(Json::as_str) {
744 block.partial_json.push_str(partial);
745 block.has_partial_json = true;
746 }
747 }
748 "citations_delta" => {
749 if let Some(citation) = delta.get("citation") {
750 block.citations.push(citation.clone());
751 block.has_citations = true;
752 }
753 }
754 _ => {}
757 }
758 }
759
760 fn observe_message_delta(&mut self, event: &Json) {
761 if let Some(delta) = event.get("delta") {
762 if let Some(reason) = delta.get("stop_reason").and_then(Json::as_str) {
763 self.stop_reason = Some(reason.to_string());
764 }
765 if let Some(seq) = delta.get("stop_sequence") {
766 self.stop_sequence = Some(seq.clone());
767 }
768 }
769 if let Some(usage) = event.get("usage") {
770 self.usage = Some(usage.clone());
771 }
772 }
773
774 fn finalize(self) -> Json {
775 let mut output = serde_json::Map::new();
776 if let Some(id) = self.id {
777 output.insert("id".to_string(), Json::String(id));
778 }
779 if let Some(t) = self.type_ {
780 output.insert("type".to_string(), Json::String(t));
781 }
782 if let Some(role) = self.role {
783 output.insert("role".to_string(), Json::String(role));
784 }
785 if let Some(model) = self.model {
786 output.insert("model".to_string(), Json::String(model));
787 }
788 let content: Vec<Json> = self
789 .blocks
790 .into_iter()
791 .filter_map(|block| block.map(StreamingBlock::finalize))
792 .collect();
793 output.insert("content".to_string(), Json::Array(content));
794 if let Some(reason) = self.stop_reason {
795 output.insert("stop_reason".to_string(), Json::String(reason));
796 }
797 if let Some(seq) = self.stop_sequence {
798 output.insert("stop_sequence".to_string(), seq);
799 }
800 if let Some(usage) = self.usage {
801 output.insert("usage".to_string(), usage);
802 }
803 Json::Object(output)
804 }
805}
806
807impl StreamingBlock {
808 fn finalize(mut self) -> Json {
809 if self.has_text {
810 self.skeleton
811 .insert("text".to_string(), Json::String(self.text));
812 }
813 if self.has_partial_json {
814 let parsed = match serde_json::from_str::<Json>(&self.partial_json) {
819 Ok(value) => value,
820 Err(_) => Json::String(self.partial_json),
821 };
822 self.skeleton.insert("input".to_string(), parsed);
823 }
824 if self.has_citations {
825 self.skeleton
826 .insert("citations".to_string(), Json::Array(self.citations));
827 }
828 Json::Object(self.skeleton)
829 }
830}
831
// Unit tests for this module live in a separate file; the `#[path]`
// attribute below points at it.
#[cfg(test)]
#[path = "../../tests/unit/codec/anthropic_tests.rs"]
mod tests;