1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::HashMap;
5use tera::{Context as TeraContext, Filter, Tera, Value};
6
7use crate::provider::{ChatRequest, Message, MessageContent, ContentPart};
8
9#[derive(Clone)]
11pub struct TemplateProcessor {
12 tera: Tera,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EndpointTemplates {
18 #[serde(default)]
20 pub template: Option<TemplateConfig>,
21
22 #[serde(default)]
24 pub model_templates: HashMap<String, TemplateConfig>,
25
26 #[serde(default)]
28 pub model_template_patterns: HashMap<String, TemplateConfig>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TemplateConfig {
34 pub request: Option<String>,
36 pub response: Option<String>,
38 pub stream_response: Option<String>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelEndpointTemplates {
45 #[serde(default)]
46 pub chat: Option<TemplateConfig>,
47 #[serde(default)]
48 pub images: Option<TemplateConfig>,
49 #[serde(default)]
50 pub embeddings: Option<TemplateConfig>,
51}
52
53impl EndpointTemplates {
54 #[allow(dead_code)]
56 pub fn get_template_for_model(&self, model_name: &str, template_type: &str) -> Option<String> {
57 if let Some(template) = self.model_templates.get(model_name) {
59 return match template_type {
60 "request" => template.request.clone(),
61 "response" => template.response.clone(),
62 "stream_response" => template.stream_response.clone(),
63 _ => None,
64 };
65 }
66
67 for (pattern, template) in &self.model_template_patterns {
69 if let Ok(re) = regex::Regex::new(pattern) {
70 if re.is_match(model_name) {
71 return match template_type {
72 "request" => template.request.clone(),
73 "response" => template.response.clone(),
74 "stream_response" => template.stream_response.clone(),
75 _ => None,
76 };
77 }
78 }
79 }
80
81 if let Some(template) = &self.template {
83 return match template_type {
84 "request" => template.request.clone(),
85 "response" => template.response.clone(),
86 "stream_response" => template.stream_response.clone(),
87 _ => None,
88 };
89 }
90
91 None
92 }
93}
94
95impl TemplateProcessor {
96 pub fn new() -> Result<Self> {
98 let mut tera = Tera::default();
99
100 tera.register_filter("json", JsonFilter);
102 tera.register_filter("gemini_role", GeminiRoleFilter);
103 tera.register_filter("system_to_user_role", SystemToUserRoleFilter);
104 tera.register_filter("default", DefaultFilter);
105 tera.register_filter("select_tool_calls", SelectToolCallsFilter);
106 tera.register_filter("from_json", FromJsonFilter);
107 tera.register_filter("selectattr", SelectAttrFilter);
108
109 Ok(Self { tera })
110 }
111
112 #[allow(dead_code)]
114 pub fn render_template(&mut self, name: &str, template: &str, context: &TeraContext) -> Result<String> {
115 self.tera.add_raw_template(name, template)?;
116 Ok(self.tera.render(name, context)?)
117 }
118
119 pub fn process_request(
121 &mut self,
122 request: &ChatRequest,
123 template: &str,
124 provider_vars: &HashMap<String, String>,
125 ) -> Result<JsonValue> {
126 self.tera
128 .add_raw_template("request", template)
129 .context("Failed to parse request template")?;
130
131 let mut context = TeraContext::new();
133
134 context.insert("model", &request.model);
136 context.insert("max_tokens", &request.max_tokens);
137 context.insert("temperature", &request.temperature);
138 context.insert("stream", &request.stream);
139 context.insert("tools", &request.tools);
140
141 let processed_messages = self.process_messages(&request.messages)?;
143 context.insert("messages", &processed_messages);
144
145 if let Some(system_msg) = request.messages.iter().find(|m| m.role == "system") {
147 if let Some(content) = system_msg.get_text_content() {
148 context.insert("system_prompt", content);
149 }
150 }
151
152 for (key, value) in provider_vars {
154 context.insert(key, value);
155 }
156
157 let rendered = self.tera
159 .render("request", &context)
160 .context("Failed to render request template")?;
161
162 let json_value: JsonValue = serde_json::from_str(&rendered)
164 .context("Template did not produce valid JSON")?;
165
166 Ok(json_value)
167 }
168
169 pub fn process_image_request(
171 &mut self,
172 request: &crate::provider::ImageGenerationRequest,
173 template: &str,
174 provider_vars: &HashMap<String, String>,
175 ) -> Result<JsonValue> {
176 self.tera
178 .add_raw_template("image_request", template)
179 .context("Failed to parse image request template")?;
180
181 let mut context = TeraContext::new();
183
184 context.insert("prompt", &request.prompt);
186 context.insert("model", &request.model);
187 context.insert("n", &request.n);
188 context.insert("size", &request.size);
189 context.insert("quality", &request.quality);
190 context.insert("style", &request.style);
191 context.insert("response_format", &request.response_format);
192
193 for (key, value) in provider_vars {
195 context.insert(key, value);
196 }
197
198 let rendered = self.tera
200 .render("image_request", &context)
201 .context("Failed to render image request template")?;
202
203 let json_value: JsonValue = serde_json::from_str(&rendered)
205 .context("Image template did not produce valid JSON")?;
206
207 Ok(json_value)
208 }
209
210 pub fn process_response(
212 &mut self,
213 response: &JsonValue,
214 template: &str,
215 ) -> Result<JsonValue> {
216 self.tera
218 .add_raw_template("response", template)
219 .context("Failed to parse response template")?;
220
221 let context = TeraContext::from_serialize(response)
223 .context("Failed to serialize response to context")?;
224
225 let rendered = self.tera
227 .render("response", &context)
228 .context("Failed to render response template")?;
229
230 let json_value: JsonValue = serde_json::from_str(&rendered)
232 .context("Response template did not produce valid JSON")?;
233
234 Ok(json_value)
235 }
236
237 fn process_messages(&self, messages: &[Message]) -> Result<Vec<ProcessedMessage>> {
239 let mut processed = Vec::new();
240
241 for message in messages {
242 let mut proc_msg = ProcessedMessage {
243 role: message.role.clone(),
244 content: None,
245 images: Vec::new(),
246 tool_calls: message.tool_calls.clone(),
247 tool_call_id: message.tool_call_id.clone(),
248 };
249
250 match &message.content_type {
251 MessageContent::Text { content } => {
252 proc_msg.content = content.clone();
253 }
254 MessageContent::Multimodal { content } => {
255 for part in content {
256 match part {
257 ContentPart::Text { text } => {
258 proc_msg.content = Some(text.clone());
259 }
260 ContentPart::ImageUrl { image_url } => {
261 if let Some(data_url) = image_url.url.strip_prefix("data:") {
263 if let Some(comma_pos) = data_url.find(',') {
264 let header = &data_url[..comma_pos];
265 let data = &data_url[comma_pos + 1..];
266
267 let mime_type = if let Some(semi_pos) = header.find(';') {
268 header[..semi_pos].to_string()
269 } else {
270 header.to_string()
271 };
272
273 proc_msg.images.push(ProcessedImage {
274 mime_type,
275 data: data.to_string(),
276 url: image_url.url.clone(),
277 });
278 }
279 } else {
280 proc_msg.images.push(ProcessedImage {
282 mime_type: "image/jpeg".to_string(), data: String::new(),
284 url: image_url.url.clone(),
285 });
286 }
287 }
288 }
289 }
290 }
291 }
292
293 processed.push(proc_msg);
294 }
295
296 Ok(processed)
297 }
298}
299
300#[derive(Debug, Serialize)]
302struct ProcessedMessage {
303 role: String,
304 content: Option<String>,
305 images: Vec<ProcessedImage>,
306 tool_calls: Option<Vec<crate::provider::ToolCall>>,
307 tool_call_id: Option<String>,
308}
309
310#[derive(Debug, Serialize)]
311struct ProcessedImage {
312 mime_type: String,
313 data: String,
314 url: String,
315}
316
317struct JsonFilter;
319
320impl Filter for JsonFilter {
321 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
322 match serde_json::to_string(&value) {
323 Ok(json_str) => Ok(Value::String(json_str)),
324 Err(e) => Err(tera::Error::msg(format!("Failed to serialize to JSON: {}", e))),
325 }
326 }
327}
328
329struct GeminiRoleFilter;
331
332impl Filter for GeminiRoleFilter {
333 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
334 match value.as_str() {
335 Some("user") => Ok(Value::String("user".to_string())),
336 Some("assistant") => Ok(Value::String("model".to_string())),
337 Some("system") => Ok(Value::String("user".to_string())), Some(other) => Ok(Value::String(other.to_string())),
339 None => Ok(value.clone()),
340 }
341 }
342}
343
344struct SystemToUserRoleFilter;
346
347impl Filter for SystemToUserRoleFilter {
348 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
349 match value.as_str() {
350 Some("system") => Ok(Value::String("user".to_string())), Some(other) => Ok(Value::String(other.to_string())),
352 None => Ok(value.clone()),
353 }
354 }
355}
356
357struct DefaultFilter;
359
360impl Filter for DefaultFilter {
361 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
362 if value.is_null() || (value.is_string() && value.as_str() == Some("")) {
363 if let Some(default_value) = args.get("value") {
364 Ok(default_value.clone())
365 } else {
366 Ok(Value::Null)
367 }
368 } else {
369 Ok(value.clone())
370 }
371 }
372}
373
374struct SelectToolCallsFilter;
376
377impl Filter for SelectToolCallsFilter {
378 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
379 if let Some(array) = value.as_array() {
380 let key = args.get("key")
381 .and_then(|v| v.as_str())
382 .unwrap_or("functionCall");
383
384 let filtered: Vec<Value> = array.iter()
385 .filter(|item| {
386 item.as_object()
387 .map(|obj| obj.contains_key(key))
388 .unwrap_or(false)
389 })
390 .cloned()
391 .collect();
392
393 Ok(Value::Array(filtered))
394 } else {
395 Ok(Value::Array(vec![]))
396 }
397 }
398}
399
400struct FromJsonFilter;
402
403impl Filter for FromJsonFilter {
404 fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
405 if let Some(json_str) = value.as_str() {
406 match serde_json::from_str::<JsonValue>(json_str) {
407 Ok(parsed) => {
408 match serde_json::to_value(&parsed) {
410 Ok(tera_value) => Ok(tera_value),
411 Err(e) => Err(tera::Error::msg(format!("Failed to convert to Tera value: {}", e))),
412 }
413 }
414 Err(e) => Err(tera::Error::msg(format!("Failed to parse JSON: {}", e))),
415 }
416 } else {
417 Ok(value.clone())
418 }
419 }
420}
421
422struct SelectAttrFilter;
424
425impl Filter for SelectAttrFilter {
426 fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
427 if let Some(array) = value.as_array() {
428 let attr_name = args.get("attr")
429 .and_then(|v| v.as_str())
430 .ok_or_else(|| tera::Error::msg("selectattr filter requires 'attr' argument"))?;
431
432 let test_value = args.get("value")
433 .ok_or_else(|| tera::Error::msg("selectattr filter requires 'value' argument"))?;
434
435 let filtered: Vec<Value> = array.iter()
436 .filter(|item| {
437 if let Some(obj) = item.as_object() {
438 if let Some(attr_value) = obj.get(attr_name) {
439 attr_value == test_value
440 } else {
441 false
442 }
443 } else {
444 false
445 }
446 })
447 .cloned()
448 .collect();
449
450 Ok(Value::Array(filtered))
451 } else {
452 Ok(Value::Array(vec![]))
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a filter-argument map from `(name, value)` pairs.
    fn filter_args(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_string(), value.clone()))
            .collect()
    }

    /// Builds a config that only defines a request template.
    fn request_only(request: &str) -> TemplateConfig {
        TemplateConfig {
            request: Some(request.to_string()),
            response: None,
            stream_response: None,
        }
    }

    #[test]
    fn test_json_filter() {
        let filter = JsonFilter;
        let value = Value::String("test".to_string());
        let args = HashMap::new();

        let result = filter.filter(&value, &args).unwrap();
        assert_eq!(result, Value::String("\"test\"".to_string()));
    }

    #[test]
    fn test_gemini_role_filter() {
        let filter = GeminiRoleFilter;
        let args = HashMap::new();

        let value = Value::String("assistant".to_string());
        let result = filter.filter(&value, &args).unwrap();
        assert_eq!(result, Value::String("model".to_string()));

        let value = Value::String("system".to_string());
        let result = filter.filter(&value, &args).unwrap();
        assert_eq!(result, Value::String("user".to_string()));

        // Unknown roles and non-strings pass through untouched.
        assert_eq!(filter.filter(&json!("user"), &args).unwrap(), json!("user"));
        assert_eq!(filter.filter(&json!("tool"), &args).unwrap(), json!("tool"));
        assert_eq!(filter.filter(&json!(1), &args).unwrap(), json!(1));
    }

    #[test]
    fn test_system_to_user_role_filter() {
        let filter = SystemToUserRoleFilter;
        let args = HashMap::new();

        assert_eq!(filter.filter(&json!("system"), &args).unwrap(), json!("user"));
        assert_eq!(filter.filter(&json!("assistant"), &args).unwrap(), json!("assistant"));
        assert_eq!(filter.filter(&Value::Null, &args).unwrap(), Value::Null);
    }

    #[test]
    fn test_default_filter() {
        let filter = DefaultFilter;
        let mut args = HashMap::new();
        args.insert("value".to_string(), Value::String("default".to_string()));

        let value = Value::Null;
        let result = filter.filter(&value, &args).unwrap();
        assert_eq!(result, Value::String("default".to_string()));

        let value = Value::String("existing".to_string());
        let result = filter.filter(&value, &args).unwrap();
        assert_eq!(result, Value::String("existing".to_string()));
    }

    #[test]
    fn test_default_filter_blank_values() {
        let filter = DefaultFilter;
        let with_default = filter_args(&[("value", json!("fallback"))]);

        // Empty strings count as missing; falsy non-strings do not.
        assert_eq!(filter.filter(&json!(""), &with_default).unwrap(), json!("fallback"));
        assert_eq!(filter.filter(&json!(0), &with_default).unwrap(), json!(0));
        // Without a `value` argument a blank input becomes null.
        assert_eq!(filter.filter(&json!(""), &HashMap::new()).unwrap(), Value::Null);
    }

    #[test]
    fn test_select_tool_calls_filter() {
        let filter = SelectToolCallsFilter;
        let parts = json!([{"functionCall": {"name": "f"}}, {"text": "hi"}, "bare"]);

        assert_eq!(
            filter.filter(&parts, &HashMap::new()).unwrap(),
            json!([{"functionCall": {"name": "f"}}])
        );
        assert_eq!(
            filter.filter(&parts, &filter_args(&[("key", json!("text"))])).unwrap(),
            json!([{"text": "hi"}])
        );
        assert_eq!(filter.filter(&json!("not an array"), &HashMap::new()).unwrap(), json!([]));
    }

    #[test]
    fn test_from_json_filter() {
        let filter = FromJsonFilter;
        let args = HashMap::new();

        assert_eq!(
            filter.filter(&json!("{\"a\":[1,2]}"), &args).unwrap(),
            json!({"a": [1, 2]})
        );
        assert_eq!(filter.filter(&json!(42), &args).unwrap(), json!(42));
        assert!(filter.filter(&json!("{not json"), &args).is_err());
    }

    #[test]
    fn test_selectattr_filter() {
        let filter = SelectAttrFilter;
        let messages = json!([{"role": "tool", "id": 1}, {"role": "user", "id": 2}, "bare"]);

        let selected = filter
            .filter(
                &messages,
                &filter_args(&[("attr", json!("role")), ("value", json!("tool"))]),
            )
            .unwrap();
        assert_eq!(selected, json!([{"role": "tool", "id": 1}]));

        // Missing arguments are errors for array input...
        assert!(filter.filter(&messages, &filter_args(&[("value", json!("tool"))])).is_err());
        assert!(filter.filter(&messages, &filter_args(&[("attr", json!("role"))])).is_err());
        // ...but non-array input short-circuits to an empty array.
        assert_eq!(filter.filter(&json!({}), &HashMap::new()).unwrap(), json!([]));
    }

    #[test]
    fn test_get_template_for_model_precedence() {
        let mut templates = EndpointTemplates {
            template: Some(request_only("default")),
            model_templates: HashMap::new(),
            model_template_patterns: HashMap::new(),
        };
        templates
            .model_templates
            .insert("exact-model".to_string(), request_only("exact"));
        templates
            .model_template_patterns
            .insert("^gpt-".to_string(), request_only("pattern"));

        assert_eq!(
            templates.get_template_for_model("exact-model", "request"),
            Some("exact".to_string())
        );
        assert_eq!(
            templates.get_template_for_model("gpt-4", "request"),
            Some("pattern".to_string())
        );
        assert_eq!(
            templates.get_template_for_model("other", "request"),
            Some("default".to_string())
        );
        assert_eq!(templates.get_template_for_model("other", "response"), None);
        assert_eq!(templates.get_template_for_model("other", "bogus"), None);
    }
}