// lc/template_processor.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::HashMap;
5use tera::{Context as TeraContext, Filter, Tera, Value};
6
7use crate::provider::{ChatRequest, Message, MessageContent, ContentPart};
8
9/// Template processor for handling request/response transformations
10#[derive(Clone)]
11pub struct TemplateProcessor {
12    tera: Tera,
13}
14
15/// Endpoint-specific templates with model pattern support
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EndpointTemplates {
18    /// Default template for all models
19    #[serde(default)]
20    pub template: Option<TemplateConfig>,
21    
22    /// Model-specific templates (exact match)
23    #[serde(default)]
24    pub model_templates: HashMap<String, TemplateConfig>,
25    
26    /// Model pattern templates (regex match)
27    #[serde(default)]
28    pub model_template_patterns: HashMap<String, TemplateConfig>,
29}
30
31/// Template configuration for request/response
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TemplateConfig {
34    /// Request transformation template
35    pub request: Option<String>,
36    /// Response parsing template
37    pub response: Option<String>,
38    /// Streaming response parsing template
39    pub stream_response: Option<String>,
40}
41
42/// Model-specific endpoint templates (for backward compatibility)
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelEndpointTemplates {
45    #[serde(default)]
46    pub chat: Option<TemplateConfig>,
47    #[serde(default)]
48    pub images: Option<TemplateConfig>,
49    #[serde(default)]
50    pub embeddings: Option<TemplateConfig>,
51}
52
53impl EndpointTemplates {
54    /// Get template for a specific model, checking patterns and defaults
55    #[allow(dead_code)]
56    pub fn get_template_for_model(&self, model_name: &str, template_type: &str) -> Option<String> {
57        // First check exact match
58        if let Some(template) = self.model_templates.get(model_name) {
59            return match template_type {
60                "request" => template.request.clone(),
61                "response" => template.response.clone(),
62                "stream_response" => template.stream_response.clone(),
63                _ => None,
64            };
65        }
66        
67        // Then check regex patterns
68        for (pattern, template) in &self.model_template_patterns {
69            if let Ok(re) = regex::Regex::new(pattern) {
70                if re.is_match(model_name) {
71                    return match template_type {
72                        "request" => template.request.clone(),
73                        "response" => template.response.clone(),
74                        "stream_response" => template.stream_response.clone(),
75                        _ => None,
76                    };
77                }
78            }
79        }
80        
81        // Finally fall back to default template
82        if let Some(template) = &self.template {
83            return match template_type {
84                "request" => template.request.clone(),
85                "response" => template.response.clone(),
86                "stream_response" => template.stream_response.clone(),
87                _ => None,
88            };
89        }
90        
91        None
92    }
93}
94
95impl TemplateProcessor {
96    /// Create a new template processor
97    pub fn new() -> Result<Self> {
98        let mut tera = Tera::default();
99        
100        // Register custom filters
101        tera.register_filter("json", JsonFilter);
102        tera.register_filter("gemini_role", GeminiRoleFilter);
103        tera.register_filter("system_to_user_role", SystemToUserRoleFilter);
104        tera.register_filter("default", DefaultFilter);
105        tera.register_filter("select_tool_calls", SelectToolCallsFilter);
106        tera.register_filter("from_json", FromJsonFilter);
107        tera.register_filter("selectattr", SelectAttrFilter);
108        
109        Ok(Self { tera })
110    }
111
112    /// Render a template directly
113    #[allow(dead_code)]
114    pub fn render_template(&mut self, name: &str, template: &str, context: &TeraContext) -> Result<String> {
115        self.tera.add_raw_template(name, template)?;
116        Ok(self.tera.render(name, context)?)
117    }
118
119    /// Process a chat request using the provided template
120    pub fn process_request(
121        &mut self,
122        request: &ChatRequest,
123        template: &str,
124        provider_vars: &HashMap<String, String>,
125    ) -> Result<JsonValue> {
126        // Add template to Tera
127        self.tera
128            .add_raw_template("request", template)
129            .context("Failed to parse request template")?;
130
131        // Build context from ChatRequest
132        let mut context = TeraContext::new();
133        
134        // Add basic fields
135        context.insert("model", &request.model);
136        context.insert("max_tokens", &request.max_tokens);
137        context.insert("temperature", &request.temperature);
138        context.insert("stream", &request.stream);
139        context.insert("tools", &request.tools);
140        
141        // Process messages into a format suitable for templates
142        let processed_messages = self.process_messages(&request.messages)?;
143        context.insert("messages", &processed_messages);
144        
145        // Extract system prompt if present
146        if let Some(system_msg) = request.messages.iter().find(|m| m.role == "system") {
147            if let Some(content) = system_msg.get_text_content() {
148                context.insert("system_prompt", content);
149            }
150        }
151        
152        // Add provider-specific variables
153        for (key, value) in provider_vars {
154            context.insert(key, value);
155        }
156
157        // Render template
158        let rendered = self.tera
159            .render("request", &context)
160            .context("Failed to render request template")?;
161
162        // Parse as JSON to validate
163        let json_value: JsonValue = serde_json::from_str(&rendered)
164            .context("Template did not produce valid JSON")?;
165
166        Ok(json_value)
167    }
168
169    /// Process an image generation request using the provided template
170    pub fn process_image_request(
171        &mut self,
172        request: &crate::provider::ImageGenerationRequest,
173        template: &str,
174        provider_vars: &HashMap<String, String>,
175    ) -> Result<JsonValue> {
176        // Add template to Tera
177        self.tera
178            .add_raw_template("image_request", template)
179            .context("Failed to parse image request template")?;
180
181        // Build context from ImageGenerationRequest
182        let mut context = TeraContext::new();
183        
184        // Add basic fields
185        context.insert("prompt", &request.prompt);
186        context.insert("model", &request.model);
187        context.insert("n", &request.n);
188        context.insert("size", &request.size);
189        context.insert("quality", &request.quality);
190        context.insert("style", &request.style);
191        context.insert("response_format", &request.response_format);
192        
193        // Add provider-specific variables
194        for (key, value) in provider_vars {
195            context.insert(key, value);
196        }
197
198        // Render template
199        let rendered = self.tera
200            .render("image_request", &context)
201            .context("Failed to render image request template")?;
202
203        // Parse as JSON to validate
204        let json_value: JsonValue = serde_json::from_str(&rendered)
205            .context("Image template did not produce valid JSON")?;
206
207        Ok(json_value)
208    }
209
210    /// Process a response using the provided template
211    pub fn process_response(
212        &mut self,
213        response: &JsonValue,
214        template: &str,
215    ) -> Result<JsonValue> {
216        // Add template to Tera
217        self.tera
218            .add_raw_template("response", template)
219            .context("Failed to parse response template")?;
220
221        // Build context from response
222        let context = TeraContext::from_serialize(response)
223            .context("Failed to serialize response to context")?;
224
225        // Render template
226        let rendered = self.tera
227            .render("response", &context)
228            .context("Failed to render response template")?;
229
230        // Parse as JSON
231        let json_value: JsonValue = serde_json::from_str(&rendered)
232            .context("Response template did not produce valid JSON")?;
233
234        Ok(json_value)
235    }
236
237    /// Process messages into a format suitable for templates
238    fn process_messages(&self, messages: &[Message]) -> Result<Vec<ProcessedMessage>> {
239        let mut processed = Vec::new();
240
241        for message in messages {
242            let mut proc_msg = ProcessedMessage {
243                role: message.role.clone(),
244                content: None,
245                images: Vec::new(),
246                tool_calls: message.tool_calls.clone(),
247                tool_call_id: message.tool_call_id.clone(),
248            };
249
250            match &message.content_type {
251                MessageContent::Text { content } => {
252                    proc_msg.content = content.clone();
253                }
254                MessageContent::Multimodal { content } => {
255                    for part in content {
256                        match part {
257                            ContentPart::Text { text } => {
258                                proc_msg.content = Some(text.clone());
259                            }
260                            ContentPart::ImageUrl { image_url } => {
261                                // Extract base64 data and mime type from data URL
262                                if let Some(data_url) = image_url.url.strip_prefix("data:") {
263                                    if let Some(comma_pos) = data_url.find(',') {
264                                        let header = &data_url[..comma_pos];
265                                        let data = &data_url[comma_pos + 1..];
266                                        
267                                        let mime_type = if let Some(semi_pos) = header.find(';') {
268                                            header[..semi_pos].to_string()
269                                        } else {
270                                            header.to_string()
271                                        };
272                                        
273                                        proc_msg.images.push(ProcessedImage {
274                                            mime_type,
275                                            data: data.to_string(),
276                                            url: image_url.url.clone(),
277                                        });
278                                    }
279                                } else {
280                                    // Regular URL
281                                    proc_msg.images.push(ProcessedImage {
282                                        mime_type: "image/jpeg".to_string(), // Default
283                                        data: String::new(),
284                                        url: image_url.url.clone(),
285                                    });
286                                }
287                            }
288                        }
289                    }
290                }
291            }
292
293            processed.push(proc_msg);
294        }
295
296        Ok(processed)
297    }
298}
299
300/// Processed message format for templates
301#[derive(Debug, Serialize)]
302struct ProcessedMessage {
303    role: String,
304    content: Option<String>,
305    images: Vec<ProcessedImage>,
306    tool_calls: Option<Vec<crate::provider::ToolCall>>,
307    tool_call_id: Option<String>,
308}
309
310#[derive(Debug, Serialize)]
311struct ProcessedImage {
312    mime_type: String,
313    data: String,
314    url: String,
315}
316
317/// Custom filter to convert values to JSON
318struct JsonFilter;
319
320impl Filter for JsonFilter {
321    fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
322        match serde_json::to_string(&value) {
323            Ok(json_str) => Ok(Value::String(json_str)),
324            Err(e) => Err(tera::Error::msg(format!("Failed to serialize to JSON: {}", e))),
325        }
326    }
327}
328
329/// Filter to convert OpenAI roles to Gemini roles
330struct GeminiRoleFilter;
331
332impl Filter for GeminiRoleFilter {
333    fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
334        match value.as_str() {
335            Some("user") => Ok(Value::String("user".to_string())),
336            Some("assistant") => Ok(Value::String("model".to_string())),
337            Some("system") => Ok(Value::String("user".to_string())), // Gemini handles system as user
338            Some(other) => Ok(Value::String(other.to_string())),
339            None => Ok(value.clone()),
340        }
341    }
342}
343
344/// Filter to convert system roles to user roles (for providers that don't support system roles)
345struct SystemToUserRoleFilter;
346
347impl Filter for SystemToUserRoleFilter {
348    fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
349        match value.as_str() {
350            Some("system") => Ok(Value::String("user".to_string())), // Convert system to user
351            Some(other) => Ok(Value::String(other.to_string())),
352            None => Ok(value.clone()),
353        }
354    }
355}
356
357/// Filter to provide default values
358struct DefaultFilter;
359
360impl Filter for DefaultFilter {
361    fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
362        if value.is_null() || (value.is_string() && value.as_str() == Some("")) {
363            if let Some(default_value) = args.get("value") {
364                Ok(default_value.clone())
365            } else {
366                Ok(Value::Null)
367            }
368        } else {
369            Ok(value.clone())
370        }
371    }
372}
373
374/// Filter to select items with tool calls
375struct SelectToolCallsFilter;
376
377impl Filter for SelectToolCallsFilter {
378    fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
379        if let Some(array) = value.as_array() {
380            let key = args.get("key")
381                .and_then(|v| v.as_str())
382                .unwrap_or("functionCall");
383                
384            let filtered: Vec<Value> = array.iter()
385                .filter(|item| {
386                    item.as_object()
387                        .map(|obj| obj.contains_key(key))
388                        .unwrap_or(false)
389                })
390                .cloned()
391                .collect();
392                
393            Ok(Value::Array(filtered))
394        } else {
395            Ok(Value::Array(vec![]))
396        }
397    }
398}
399
400/// Filter to parse JSON strings
401struct FromJsonFilter;
402
403impl Filter for FromJsonFilter {
404    fn filter(&self, value: &Value, _args: &HashMap<String, Value>) -> tera::Result<Value> {
405        if let Some(json_str) = value.as_str() {
406            match serde_json::from_str::<JsonValue>(json_str) {
407                Ok(parsed) => {
408                    // Convert JsonValue to Tera Value
409                    match serde_json::to_value(&parsed) {
410                        Ok(tera_value) => Ok(tera_value),
411                        Err(e) => Err(tera::Error::msg(format!("Failed to convert to Tera value: {}", e))),
412                    }
413                }
414                Err(e) => Err(tera::Error::msg(format!("Failed to parse JSON: {}", e))),
415            }
416        } else {
417            Ok(value.clone())
418        }
419    }
420}
421
422/// Filter to select items by attribute value (simplified version of Jinja2's selectattr)
423struct SelectAttrFilter;
424
425impl Filter for SelectAttrFilter {
426    fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
427        if let Some(array) = value.as_array() {
428            let attr_name = args.get("attr")
429                .and_then(|v| v.as_str())
430                .ok_or_else(|| tera::Error::msg("selectattr filter requires 'attr' argument"))?;
431            
432            let test_value = args.get("value")
433                .ok_or_else(|| tera::Error::msg("selectattr filter requires 'value' argument"))?;
434                
435            let filtered: Vec<Value> = array.iter()
436                .filter(|item| {
437                    if let Some(obj) = item.as_object() {
438                        if let Some(attr_value) = obj.get(attr_name) {
439                            attr_value == test_value
440                        } else {
441                            false
442                        }
443                    } else {
444                        false
445                    }
446                })
447                .cloned()
448                .collect();
449                
450            Ok(Value::Array(filtered))
451        } else {
452            Ok(Value::Array(vec![]))
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a string `Value` from a literal.
    fn s(text: &str) -> Value {
        Value::String(text.to_string())
    }

    #[test]
    fn test_json_filter() {
        let out = JsonFilter.filter(&s("test"), &HashMap::new()).unwrap();
        assert_eq!(out, s("\"test\""));
    }

    #[test]
    fn test_gemini_role_filter() {
        let no_args = HashMap::new();
        for &(input, expected) in [("assistant", "model"), ("system", "user")].iter() {
            let out = GeminiRoleFilter.filter(&s(input), &no_args).unwrap();
            assert_eq!(out, s(expected));
        }
    }

    #[test]
    fn test_default_filter() {
        let mut args = HashMap::new();
        args.insert("value".to_string(), s("default"));

        // Null input falls back to the provided default.
        assert_eq!(DefaultFilter.filter(&Value::Null, &args).unwrap(), s("default"));
        // A non-blank value is kept.
        assert_eq!(DefaultFilter.filter(&s("existing"), &args).unwrap(), s("existing"));
    }
}