1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::HashMap;
5use tera::{Context as TeraContext, Filter, Tera, Value};
6
7use crate::provider::{ChatRequest, ContentPart, Message, MessageContent};
8
9#[derive(Clone)]
11pub struct TemplateProcessor {
12 tera: Tera,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EndpointTemplates {
18 #[serde(default)]
20 pub template: Option<TemplateConfig>,
21
22 #[serde(default)]
24 pub model_templates: HashMap<String, TemplateConfig>,
25
26 #[serde(default)]
28 pub model_template_patterns: HashMap<String, TemplateConfig>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TemplateConfig {
34 pub request: Option<String>,
36 pub response: Option<String>,
38 pub stream_response: Option<String>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelEndpointTemplates {
45 #[serde(default)]
46 pub chat: Option<TemplateConfig>,
47 #[serde(default)]
48 pub images: Option<TemplateConfig>,
49 #[serde(default)]
50 pub embeddings: Option<TemplateConfig>,
51}
52
53impl EndpointTemplates {
54 #[allow(dead_code)]
56 pub fn get_template_for_model(&self, model_name: &str, template_type: &str) -> Option<String> {
57 if let Some(template) = self.model_templates.get(model_name) {
59 return match template_type {
60 "request" => template.request.clone(),
61 "response" => template.response.clone(),
62 "stream_response" => template.stream_response.clone(),
63 _ => None,
64 };
65 }
66
67 for (pattern, template) in &self.model_template_patterns {
69 if let Ok(re) = regex::Regex::new(pattern) {
70 if re.is_match(model_name) {
71 return match template_type {
72 "request" => template.request.clone(),
73 "response" => template.response.clone(),
74 "stream_response" => template.stream_response.clone(),
75 _ => None,
76 };
77 }
78 }
79 }
80
81 if let Some(template) = &self.template {
83 return match template_type {
84 "request" => template.request.clone(),
85 "response" => template.response.clone(),
86 "stream_response" => template.stream_response.clone(),
87 _ => None,
88 };
89 }
90
91 None
92 }
93}
94
95impl TemplateProcessor {
96 pub fn new() -> Result<Self> {
98 let mut tera = Tera::default();
99
100 tera.register_filter("json", JsonFilter);
102 tera.register_filter("gemini_role", GeminiRoleFilter);
103 tera.register_filter("system_to_user_role", SystemToUserRoleFilter);
104 tera.register_filter("default", DefaultFilter);
105 tera.register_filter("select_tool_calls", SelectToolCallsFilter);
106 tera.register_filter("from_json", FromJsonFilter);
107 tera.register_filter("selectattr", SelectAttrFilter);
108 tera.register_filter("base_messages", BaseMessagesFilter);
109 tera.register_filter("anthropic_messages", AnthropicMessagesFilter);
110 tera.register_filter("gemini_messages", GeminiMessagesFilter);
111
112 Ok(Self { tera })
113 }
114
115 #[allow(dead_code)]
117 pub fn render_template(
118 &mut self,
119 name: &str,
120 template: &str,
121 context: &TeraContext,
122 ) -> Result<String> {
123 self.tera.add_raw_template(name, template)?;
124 Ok(self.tera.render(name, context)?)
125 }
126
127 pub fn process_request(
129 &mut self,
130 request: &ChatRequest,
131 template: &str,
132 provider_vars: &HashMap<String, String>,
133 ) -> Result<JsonValue> {
134 self.tera
136 .add_raw_template("request", template)
137 .context("Failed to parse request template")?;
138
139 let mut context = TeraContext::new();
141
142 context.insert("model", &request.model);
144 context.insert("max_tokens", &request.max_tokens);
145 context.insert("temperature", &request.temperature);
146 context.insert("stream", &request.stream);
147 context.insert("tools", &request.tools);
148
149 let processed_messages = self.process_messages(&request.messages)?;
151 context.insert("messages", &processed_messages);
152
153 if let Some(system_msg) = request.messages.iter().find(|m| m.role == "system") {
155 if let Some(content) = system_msg.get_text_content() {
156 context.insert("system_prompt", content);
157 }
158 }
159
160 for (key, value) in provider_vars {
162 context.insert(key, value);
163 }
164
165 let rendered = self
167 .tera
168 .render("request", &context)
169 .context("Failed to render request template")?;
170
171 let json_value: JsonValue =
173 serde_json::from_str(&rendered).context("Template did not produce valid JSON")?;
174
175 Ok(json_value)
176 }
177
178 pub fn process_image_request(
180 &mut self,
181 request: &crate::provider::ImageGenerationRequest,
182 template: &str,
183 provider_vars: &HashMap<String, String>,
184 ) -> Result<JsonValue> {
185 self.tera
187 .add_raw_template("image_request", template)
188 .context("Failed to parse image request template")?;
189
190 let mut context = TeraContext::new();
192
193 context.insert("prompt", &request.prompt);
195 context.insert("model", &request.model);
196 context.insert("n", &request.n);
197 context.insert("size", &request.size);
198 context.insert("quality", &request.quality);
199 context.insert("style", &request.style);
200 context.insert("response_format", &request.response_format);
201
202 for (key, value) in provider_vars {
204 context.insert(key, value);
205 }
206
207 let rendered = self
209 .tera
210 .render("image_request", &context)
211 .context("Failed to render image request template")?;
212
213 let json_value: JsonValue =
215 serde_json::from_str(&rendered).context("Image template did not produce valid JSON")?;
216
217 Ok(json_value)
218 }
219
220 #[allow(dead_code)]
222 pub fn process_audio_request(
223 &mut self,
224 request: &crate::provider::AudioTranscriptionRequest,
225 template: &str,
226 provider_vars: &HashMap<String, String>,
227 ) -> Result<JsonValue> {
228 self.tera
230 .add_raw_template("audio_request", template)
231 .context("Failed to parse audio request template")?;
232
233 let mut context = TeraContext::new();
235
236 context.insert("file", &request.file);
238 context.insert("model", &request.model);
239 context.insert("language", &request.language);
240 context.insert("prompt", &request.prompt);
241 context.insert("response_format", &request.response_format);
242 context.insert("temperature", &request.temperature);
243
244 for (key, value) in provider_vars {
246 context.insert(key, value);
247 }
248
249 let rendered = self
251 .tera
252 .render("audio_request", &context)
253 .context("Failed to render audio request template")?;
254
255 let json_value: JsonValue =
257 serde_json::from_str(&rendered).context("Audio template did not produce valid JSON")?;
258
259 Ok(json_value)
260 }
261
262 pub fn process_speech_request(
264 &mut self,
265 request: &crate::provider::AudioSpeechRequest,
266 template: &str,
267 provider_vars: &HashMap<String, String>,
268 ) -> Result<JsonValue> {
269 self.tera
271 .add_raw_template("speech_request", template)
272 .context("Failed to parse speech request template")?;
273
274 let mut context = TeraContext::new();
276
277 context.insert("model", &request.model);
279 context.insert("input", &request.input);
280 context.insert("voice", &request.voice);
281 context.insert("response_format", &request.response_format);
282 context.insert("speed", &request.speed);
283
284 for (key, value) in provider_vars {
286 context.insert(key, value);
287 }
288
289 let rendered = self
291 .tera
292 .render("speech_request", &context)
293 .context("Failed to render speech request template")?;
294
295 let json_value: JsonValue = serde_json::from_str(&rendered)
297 .context("Speech template did not produce valid JSON")?;
298
299 Ok(json_value)
300 }
301
302 pub fn process_embeddings_request(
304 &mut self,
305 request: &crate::provider::EmbeddingRequest,
306 template: &str,
307 provider_vars: &HashMap<String, String>,
308 ) -> Result<JsonValue> {
309 self.tera
311 .add_raw_template("embeddings_request", template)
312 .context("Failed to parse embeddings request template")?;
313
314 let mut context = TeraContext::new();
316
317 context.insert("model", &request.model);
319 context.insert("input", &request.input);
320 context.insert("encoding_format", &request.encoding_format);
321
322 for (key, value) in provider_vars {
324 context.insert(key, value);
325 }
326
327 let rendered = self
329 .tera
330 .render("embeddings_request", &context)
331 .context("Failed to render embeddings request template")?;
332
333 let json_value: JsonValue = serde_json::from_str(&rendered)
335 .context("Embeddings template did not produce valid JSON")?;
336
337 Ok(json_value)
338 }
339
340 pub fn process_response(&mut self, response: &JsonValue, template: &str) -> Result<JsonValue> {
342 self.tera
344 .add_raw_template("response", template)
345 .context("Failed to parse response template")?;
346
347 let context = TeraContext::from_serialize(response)
349 .context("Failed to serialize response to context")?;
350
351 let rendered = self
353 .tera
354 .render("response", &context)
355 .context("Failed to render response template")?;
356
357 let json_value: JsonValue = serde_json::from_str(&rendered)
359 .context("Response template did not produce valid JSON")?;
360
361 Ok(json_value)
362 }
363
364 fn process_messages(&self, messages: &[Message]) -> Result<Vec<ProcessedMessage>> {
366 let mut processed = Vec::new();
367
368 for message in messages {
369 let mut proc_msg = ProcessedMessage {
370 role: message.role.clone(),
371 content: None,
372 images: Vec::new(),
373 tool_calls: message.tool_calls.clone(),
374 tool_call_id: message.tool_call_id.clone(),
375 };
376
377 match &message.content_type {
378 MessageContent::Text { content } => {
379 proc_msg.content = content.clone();
380 }
381 MessageContent::Multimodal { content } => {
382 for part in content {
383 match part {
384 ContentPart::Text { text } => {
385 proc_msg.content = Some(text.clone());
386 }
387 ContentPart::ImageUrl { image_url } => {
388 if let Some(data_url) = image_url.url.strip_prefix("data:") {
390 if let Some(comma_pos) = data_url.find(',') {
391 let header = &data_url[..comma_pos];
392 let data = &data_url[comma_pos + 1..];
393
394 let mime_type = if let Some(semi_pos) = header.find(';') {
395 header[..semi_pos].to_string()
396 } else {
397 header.to_string()
398 };
399
400 proc_msg.images.push(ProcessedImage {
401 mime_type,
402 data: data.to_string(),
403 url: image_url.url.clone(),
404 });
405 }
406 } else {
407 proc_msg.images.push(ProcessedImage {
409 mime_type: "image/jpeg".to_string(), data: String::new(),
411 url: image_url.url.clone(),
412 });
413 }
414 }
415 }
416 }
417 }
418 }
419
420 processed.push(proc_msg);
421 }
422
423 Ok(processed)
424 }
425}
426
427#[derive(Debug, Serialize)]
429struct ProcessedMessage {
430 role: String,
431 content: Option<String>,
432 images: Vec<ProcessedImage>,
433 tool_calls: Option<Vec<crate::provider::ToolCall>>,
434 tool_call_id: Option<String>,
435}
436
437#[derive(Debug, Serialize)]
438struct ProcessedImage {
439 mime_type: String,
440 data: String,
441 url: String,
442}
443
444struct JsonFilter;
446
447impl Filter for JsonFilter {
448 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
449 match serde_json::to_string(&value) {
450 Ok(json_str) => Ok(Value::String(json_str)),
451 Err(e) => Err(tera::Error::msg(format!(
452 "Failed to serialize to JSON: {}",
453 e
454 ))),
455 }
456 }
457}
458
459struct GeminiRoleFilter;
461
462impl Filter for GeminiRoleFilter {
463 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
464 match value.as_str() {
465 Some("user") => Ok(Value::String("user".to_string())),
466 Some("assistant") => Ok(Value::String("model".to_string())),
467 Some("system") => Ok(Value::String("user".to_string())), Some(other) => Ok(Value::String(other.to_string())),
469 None => Ok(value.clone()),
470 }
471 }
472}
473
474struct SystemToUserRoleFilter;
476
477impl Filter for SystemToUserRoleFilter {
478 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
479 match value.as_str() {
480 Some("system") => Ok(Value::String("user".to_string())), Some(other) => Ok(Value::String(other.to_string())),
482 None => Ok(value.clone()),
483 }
484 }
485}
486
487struct DefaultFilter;
489
490impl Filter for DefaultFilter {
491 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
492 if value.is_null() || (value.is_string() && value.as_str() == Some("")) {
493 if let Some(default_value) = args.get("value") {
494 Ok(default_value.clone())
495 } else {
496 Ok(Value::Null)
497 }
498 } else {
499 Ok(value.clone())
500 }
501 }
502}
503
504struct SelectToolCallsFilter;
506
507impl Filter for SelectToolCallsFilter {
508 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
509 if let Some(array) = value.as_array() {
510 let key = args
511 .get("key")
512 .and_then(|v| v.as_str())
513 .unwrap_or("functionCall");
514
515 let filtered: Vec<Value> = array
516 .iter()
517 .filter(|item| {
518 item.as_object()
519 .map(|obj| obj.contains_key(key))
520 .unwrap_or(false)
521 })
522 .cloned()
523 .collect();
524
525 Ok(Value::Array(filtered))
526 } else {
527 Ok(Value::Array(vec![]))
528 }
529 }
530}
531
532struct FromJsonFilter;
534
535impl Filter for FromJsonFilter {
536 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
537 if let Some(json_str) = value.as_str() {
538 match serde_json::from_str::<JsonValue>(json_str) {
539 Ok(parsed) => {
540 match serde_json::to_value(&parsed) {
542 Ok(tera_value) => Ok(tera_value),
543 Err(e) => Err(tera::Error::msg(format!(
544 "Failed to convert to Tera value: {}",
545 e
546 ))),
547 }
548 }
549 Err(e) => Err(tera::Error::msg(format!("Failed to parse JSON: {}", e))),
550 }
551 } else {
552 Ok(value.clone())
553 }
554 }
555}
556
557struct SelectAttrFilter;
559
560impl Filter for SelectAttrFilter {
561 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
562 if let Some(array) = value.as_array() {
563 let attr_name = args
564 .get("attr")
565 .and_then(|v| v.as_str())
566 .ok_or_else(|| tera::Error::msg("selectattr filter requires 'attr' argument"))?;
567
568 let test_value = args
569 .get("value")
570 .ok_or_else(|| tera::Error::msg("selectattr filter requires 'value' argument"))?;
571
572 let filtered: Vec<Value> = array
573 .iter()
574 .filter(|item| {
575 if let Some(obj) = item.as_object() {
576 if let Some(attr_value) = obj.get(attr_name) {
577 attr_value == test_value
578 } else {
579 false
580 }
581 } else {
582 false
583 }
584 })
585 .cloned()
586 .collect();
587
588 Ok(Value::Array(filtered))
589 } else {
590 Ok(Value::Array(vec![]))
591 }
592 }
593}
594
595struct BaseMessagesFilter;
597
598impl Filter for BaseMessagesFilter {
599 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
600 if let Some(array) = value.as_array() {
601 let cleaned: Vec<Value> = array
602 .iter()
603 .map(|item| {
604 if let Some(obj) = item.as_object() {
605 let mut cleaned_obj = serde_json::Map::new();
606
607 for (key, value) in obj {
609 match key.as_str() {
610 "role" | "content" => {
611 cleaned_obj.insert(key.clone(), value.clone());
613 }
614 "tool_calls" => {
615 if !value.is_null()
617 && value.as_array().map_or(true, |arr| !arr.is_empty())
618 {
619 cleaned_obj.insert(key.clone(), value.clone());
620 }
621 }
622 "tool_call_id" => {
623 if !value.is_null()
625 && value.as_str().map_or(false, |s| !s.is_empty())
626 {
627 cleaned_obj.insert(key.clone(), value.clone());
628 }
629 }
630 _ => {}
632 }
633 }
634
635 Value::Object(cleaned_obj)
636 } else {
637 item.clone()
638 }
639 })
640 .collect();
641
642 Ok(Value::Array(cleaned))
643 } else {
644 Ok(value.clone())
645 }
646 }
647}
648
649struct AnthropicMessagesFilter;
651
652impl Filter for AnthropicMessagesFilter {
653 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
654 if let Some(array) = value.as_array() {
655 let converted: Vec<Value> = array
656 .iter()
657 .map(|item| {
658 if let Some(obj) = item.as_object() {
659 let mut anthropic_msg = serde_json::Map::new();
660
661 if let Some(role) = obj.get("role") {
663 anthropic_msg.insert("role".to_string(), role.clone());
664 }
665
666 let mut content_parts = Vec::new();
668
669 if let Some(text_content) = obj.get("content") {
671 if !text_content.is_null()
672 && text_content.as_str().map_or(false, |s| !s.is_empty())
673 {
674 let text_part = serde_json::json!({
675 "type": "text",
676 "text": text_content
677 });
678 content_parts.push(text_part);
679 }
680 }
681
682 if let Some(images) = obj.get("images") {
684 if let Some(images_array) = images.as_array() {
685 for image in images_array {
686 if let Some(image_obj) = image.as_object() {
687 if let (Some(data), Some(mime_type)) = (
688 image_obj.get("data").and_then(|v| v.as_str()),
689 image_obj.get("mime_type").and_then(|v| v.as_str()),
690 ) {
691 if !data.is_empty() {
692 let image_part = serde_json::json!({
694 "type": "image",
695 "source": {
696 "type": "base64",
697 "media_type": mime_type,
698 "data": data
699 }
700 });
701 content_parts.push(image_part);
702 }
703 } else if let Some(url) =
704 image_obj.get("url").and_then(|v| v.as_str())
705 {
706 if !url.starts_with("data:") && !url.is_empty() {
707 let image_part = serde_json::json!({
709 "type": "image",
710 "source": {
711 "type": "url",
712 "url": url
713 }
714 });
715 content_parts.push(image_part);
716 }
717 }
718 }
719 }
720 }
721 }
722
723 if content_parts.len() > 1
725 || (content_parts.len() == 1
726 && content_parts[0].get("type")
727 == Some(&serde_json::Value::String("image".to_string())))
728 {
729 anthropic_msg.insert(
730 "content".to_string(),
731 serde_json::Value::Array(content_parts),
732 );
733 } else if let Some(first_part) = content_parts.first() {
734 if let Some(text) = first_part.get("text") {
735 anthropic_msg.insert("content".to_string(), text.clone());
736 }
737 }
738
739 if let Some(tool_calls) = obj.get("tool_calls") {
741 if !tool_calls.is_null()
742 && tool_calls.as_array().map_or(true, |arr| !arr.is_empty())
743 {
744 anthropic_msg.insert("tool_calls".to_string(), tool_calls.clone());
745 }
746 }
747
748 if let Some(tool_call_id) = obj.get("tool_call_id") {
750 if !tool_call_id.is_null()
751 && tool_call_id.as_str().map_or(false, |s| !s.is_empty())
752 {
753 anthropic_msg
754 .insert("tool_call_id".to_string(), tool_call_id.clone());
755 }
756 }
757
758 Value::Object(anthropic_msg)
759 } else {
760 item.clone()
761 }
762 })
763 .collect();
764
765 Ok(Value::Array(converted))
766 } else {
767 Ok(value.clone())
768 }
769 }
770}
771
772struct GeminiMessagesFilter;
774
775impl Filter for GeminiMessagesFilter {
776 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
777 if let Some(array) = value.as_array() {
778 let converted: Vec<Value> = array
779 .iter()
780 .map(|item| {
781 if let Some(obj) = item.as_object() {
782 let mut gemini_msg = serde_json::Map::new();
783
784 if let Some(role) = obj.get("role").and_then(|v| v.as_str()) {
786 let gemini_role = match role {
787 "assistant" => "model",
788 "system" => "user", other => other,
790 };
791 gemini_msg.insert(
792 "role".to_string(),
793 serde_json::Value::String(gemini_role.to_string()),
794 );
795 }
796
797 let mut parts = Vec::new();
799
800 if let Some(text_content) = obj.get("content") {
802 if !text_content.is_null()
803 && text_content.as_str().map_or(false, |s| !s.is_empty())
804 {
805 let text_part = serde_json::json!({
806 "text": text_content
807 });
808 parts.push(text_part);
809 }
810 }
811
812 if let Some(images) = obj.get("images") {
814 if let Some(images_array) = images.as_array() {
815 for image in images_array {
816 if let Some(image_obj) = image.as_object() {
817 if let (Some(data), Some(mime_type)) = (
818 image_obj.get("data").and_then(|v| v.as_str()),
819 image_obj.get("mime_type").and_then(|v| v.as_str()),
820 ) {
821 if !data.is_empty() {
822 let image_part = serde_json::json!({
824 "inlineData": {
825 "mimeType": mime_type,
826 "data": data
827 }
828 });
829 parts.push(image_part);
830 }
831 }
832 }
833 }
834 }
835 }
836
837 gemini_msg.insert("parts".to_string(), serde_json::Value::Array(parts));
839
840 if let Some(tool_calls) = obj.get("tool_calls") {
842 if !tool_calls.is_null()
843 && tool_calls.as_array().map_or(true, |arr| !arr.is_empty())
844 {
845 gemini_msg.insert("tool_calls".to_string(), tool_calls.clone());
846 }
847 }
848
849 if let Some(tool_call_id) = obj.get("tool_call_id") {
851 if !tool_call_id.is_null()
852 && tool_call_id.as_str().map_or(false, |s| !s.is_empty())
853 {
854 gemini_msg.insert("tool_call_id".to_string(), tool_call_id.clone());
855 }
856 }
857
858 Value::Object(gemini_msg)
859 } else {
860 item.clone()
861 }
862 })
863 .collect();
864
865 Ok(Value::Array(converted))
866 } else {
867 Ok(value.clone())
868 }
869 }
870}
871
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty argument map for filters called without arguments.
    fn no_args() -> HashMap<String, Value> {
        HashMap::new()
    }

    /// Builds an argument map from `(name, value)` pairs.
    fn args(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    #[test]
    fn test_json_filter() {
        let result = JsonFilter
            .filter(&Value::String("test".to_string()), &no_args())
            .unwrap();
        assert_eq!(result, Value::String("\"test\"".to_string()));
    }

    #[test]
    fn test_gemini_role_filter() {
        let filter = GeminiRoleFilter;
        let role = |r: &str| filter.filter(&Value::String(r.to_string()), &no_args()).unwrap();

        assert_eq!(role("assistant"), Value::String("model".to_string()));
        assert_eq!(role("system"), Value::String("user".to_string()));
        assert_eq!(role("user"), Value::String("user".to_string()));
        assert_eq!(role("tool"), Value::String("tool".to_string()));
        assert_eq!(filter.filter(&Value::Null, &no_args()).unwrap(), Value::Null);
    }

    #[test]
    fn test_system_to_user_role_filter() {
        let filter = SystemToUserRoleFilter;
        let system = Value::String("system".to_string());
        let assistant = Value::String("assistant".to_string());

        assert_eq!(
            filter.filter(&system, &no_args()).unwrap(),
            Value::String("user".to_string())
        );
        assert_eq!(filter.filter(&assistant, &no_args()).unwrap(), assistant);
        assert_eq!(filter.filter(&Value::Null, &no_args()).unwrap(), Value::Null);
    }

    #[test]
    fn test_default_filter() {
        let filter = DefaultFilter;
        let with_default = args(&[("value", Value::String("default".to_string()))]);

        assert_eq!(
            filter.filter(&Value::Null, &with_default).unwrap(),
            Value::String("default".to_string())
        );
        assert_eq!(
            filter.filter(&Value::String(String::new()), &with_default).unwrap(),
            Value::String("default".to_string())
        );
        assert_eq!(
            filter.filter(&Value::String("existing".to_string()), &with_default).unwrap(),
            Value::String("existing".to_string())
        );
        // Falsy but non-blank values are kept.
        assert_eq!(
            filter.filter(&serde_json::json!(0), &with_default).unwrap(),
            serde_json::json!(0)
        );
        // Without a `value` argument a blank input becomes null.
        assert_eq!(filter.filter(&Value::Null, &no_args()).unwrap(), Value::Null);
    }

    #[test]
    fn test_select_tool_calls_filter() {
        let filter = SelectToolCallsFilter;
        let parts = serde_json::json!([
            { "functionCall": { "name": "f" } },
            { "text": "hi" },
            3
        ]);

        assert_eq!(
            filter.filter(&parts, &no_args()).unwrap(),
            serde_json::json!([{ "functionCall": { "name": "f" } }])
        );
        assert_eq!(
            filter
                .filter(&parts, &args(&[("key", Value::String("text".to_string()))]))
                .unwrap(),
            serde_json::json!([{ "text": "hi" }])
        );
        assert_eq!(
            filter.filter(&Value::Null, &no_args()).unwrap(),
            serde_json::json!([])
        );
    }

    #[test]
    fn test_from_json_filter() {
        let filter = FromJsonFilter;

        assert_eq!(
            filter
                .filter(&Value::String("{\"a\": [1, 2]}".to_string()), &no_args())
                .unwrap(),
            serde_json::json!({ "a": [1, 2] })
        );
        assert_eq!(
            filter.filter(&serde_json::json!(5), &no_args()).unwrap(),
            serde_json::json!(5)
        );
        assert!(filter
            .filter(&Value::String("{".to_string()), &no_args())
            .is_err());
    }

    #[test]
    fn test_select_attr_filter() {
        let filter = SelectAttrFilter;
        let messages = serde_json::json!([
            { "role": "user" },
            { "role": "assistant" },
            "not an object"
        ]);
        let by_role = args(&[
            ("attr", Value::String("role".to_string())),
            ("value", Value::String("user".to_string())),
        ]);

        assert_eq!(
            filter.filter(&messages, &by_role).unwrap(),
            serde_json::json!([{ "role": "user" }])
        );
        assert!(filter
            .filter(&messages, &args(&[("value", Value::Null)]))
            .is_err());
        assert!(filter
            .filter(&messages, &args(&[("attr", Value::String("role".to_string()))]))
            .is_err());
        assert_eq!(
            filter.filter(&Value::Null, &no_args()).unwrap(),
            serde_json::json!([])
        );
    }

    #[test]
    fn test_base_messages_filter() {
        let messages = serde_json::json!([
            {
                "role": "user",
                "content": "hi",
                "images": [],
                "tool_calls": [],
                "tool_call_id": ""
            },
            {
                "role": "tool",
                "content": null,
                "tool_calls": null,
                "tool_call_id": "call_1"
            }
        ]);

        assert_eq!(
            BaseMessagesFilter.filter(&messages, &no_args()).unwrap(),
            serde_json::json!([
                { "role": "user", "content": "hi" },
                { "role": "tool", "content": null, "tool_call_id": "call_1" }
            ])
        );
    }

    #[test]
    fn test_anthropic_messages_filter_text_and_inline_image() {
        let filter = AnthropicMessagesFilter;

        let text_only = serde_json::json!([{ "role": "user", "content": "hi", "images": [] }]);
        assert_eq!(
            filter.filter(&text_only, &no_args()).unwrap(),
            serde_json::json!([{ "role": "user", "content": "hi" }])
        );

        let with_image = serde_json::json!([{
            "role": "user",
            "content": "look",
            "images": [{ "mime_type": "image/png", "data": "AAA", "url": "data:image/png;base64,AAA" }]
        }]);
        assert_eq!(
            filter.filter(&with_image, &no_args()).unwrap(),
            serde_json::json!([{
                "role": "user",
                "content": [
                    { "type": "text", "text": "look" },
                    {
                        "type": "image",
                        "source": { "type": "base64", "media_type": "image/png", "data": "AAA" }
                    }
                ]
            }])
        );
    }

    #[test]
    fn test_gemini_messages_filter() {
        let messages = serde_json::json!([{
            "role": "assistant",
            "content": "hi",
            "images": [{ "mime_type": "image/png", "data": "AAA", "url": "data:image/png;base64,AAA" }],
            "tool_calls": null
        }]);

        assert_eq!(
            GeminiMessagesFilter.filter(&messages, &no_args()).unwrap(),
            serde_json::json!([{
                "role": "model",
                "parts": [
                    { "text": "hi" },
                    { "inlineData": { "mimeType": "image/png", "data": "AAA" } }
                ]
            }])
        );
    }
}