adk_core/
instruction_template.rs1use crate::{AdkError, InvocationContext, Result};
2use regex::Regex;
3use std::sync::OnceLock;
4
5static PLACEHOLDER_REGEX: OnceLock<Regex> = OnceLock::new();
9
10fn get_placeholder_regex() -> &'static Regex {
11 PLACEHOLDER_REGEX.get_or_init(|| {
12 Regex::new(r"\{[a-zA-Z_][a-zA-Z0-9_:.]*\??\}").expect("Invalid regex pattern")
14 })
15}
16
/// Reports whether `s` is an identifier.
///
/// An identifier is non-empty, starts with an alphabetic character or `_`,
/// and continues with alphanumeric characters or `_` only. Character classes
/// are Unicode-aware (`char::is_alphabetic` / `char::is_alphanumeric`), not
/// restricted to ASCII.
fn is_identifier(s: &str) -> bool {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return false;
    };

    // `&&` short-circuits, so the tail is only scanned when the head is valid.
    (head == '_' || head.is_alphabetic()) && rest.all(|ch| ch == '_' || ch.is_alphanumeric())
}
33
34fn is_valid_state_name(var_name: &str) -> bool {
37 let parts: Vec<&str> = var_name.split(':').collect();
38
39 match parts.len() {
40 1 => is_identifier(var_name),
41 2 => {
42 let prefix = format!("{}:", parts[0]);
43 let valid_prefixes = ["app:", "user:", "temp:"];
44 valid_prefixes.contains(&prefix.as_str()) && is_identifier(parts[1])
45 }
46 _ => false,
47 }
48}
49
50async fn replace_match(ctx: &dyn InvocationContext, match_str: &str) -> Result<String> {
53 let var_name = match_str.trim_matches(|c| c == '{' || c == '}').trim();
55
56 let (var_name, optional) =
58 if let Some(name) = var_name.strip_suffix('?') { (name, true) } else { (var_name, false) };
59
60 if let Some(file_name) = var_name.strip_prefix("artifact.") {
62 let artifacts = ctx
63 .artifacts()
64 .ok_or_else(|| AdkError::Agent("Artifact service is not initialized".to_string()))?;
65
66 match artifacts.load(file_name).await {
67 Ok(part) => {
68 if let Some(text) = part.text() {
70 return Ok(text.to_string());
71 }
72 Ok(String::new())
73 }
74 Err(e) => {
75 if optional {
76 Ok(String::new())
78 } else {
79 Err(AdkError::Agent(format!("Failed to load artifact {}: {}", file_name, e)))
80 }
81 }
82 }
83 } else if is_valid_state_name(var_name) {
84 let state_value = ctx.session().state().get(var_name);
86
87 match state_value {
88 Some(value) => {
89 if let Some(s) = value.as_str() {
91 Ok(s.to_string())
92 } else {
93 Ok(format!("{}", value))
94 }
95 }
96 None => {
97 if optional {
98 Ok(String::new())
99 } else {
100 Err(AdkError::Agent(format!("State variable '{}' not found", var_name)))
101 }
102 }
103 }
104 } else {
105 Ok(match_str.to_string())
107 }
108}
109
110pub async fn inject_session_state(ctx: &dyn InvocationContext, template: &str) -> Result<String> {
133 let regex = get_placeholder_regex();
134 let mut result = String::with_capacity((template.len() as f32 * 1.2) as usize);
136 let mut last_end = 0;
137
138 for captures in regex.find_iter(template) {
139 let match_range = captures.range();
140
141 result.push_str(&template[last_end..match_range.start]);
143
144 let match_str = captures.as_str();
146 let replacement = replace_match(ctx, match_str).await?;
147 result.push_str(&replacement);
148
149 last_end = match_range.end;
150 }
151
152 result.push_str(&template[last_end..]);
154
155 Ok(result)
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_identifier() {
        assert!(is_identifier("valid_name"));
        assert!(is_identifier("_private"));
        assert!(is_identifier("name123"));
        assert!(is_identifier("_"));
        assert!(!is_identifier("123invalid"));
        assert!(!is_identifier(""));
        assert!(!is_identifier("with-dash"));
        // Separators used by state prefixes and artifact names are not identifier chars.
        assert!(!is_identifier("a:b"));
        assert!(!is_identifier("a.b"));
    }

    #[test]
    fn test_is_valid_state_name() {
        assert!(is_valid_state_name("valid_var"));
        assert!(is_valid_state_name("app:config"));
        assert!(is_valid_state_name("user:preference"));
        assert!(is_valid_state_name("temp:data"));
        assert!(!is_valid_state_name("invalid:prefix"));
        assert!(!is_valid_state_name("app:invalid-name"));
        assert!(!is_valid_state_name("too:many:parts"));
        // Edge cases: empty pieces, case-sensitive prefixes, artifact-style names.
        assert!(!is_valid_state_name(""));
        assert!(!is_valid_state_name("app:"));
        assert!(!is_valid_state_name(":name"));
        assert!(!is_valid_state_name("App:config"));
        assert!(!is_valid_state_name("artifact.file"));
    }

    #[test]
    fn test_placeholder_regex() {
        let found: Vec<&str> = get_placeholder_regex()
            .find_iter("Hi {name}, {app:cfg?} {artifact.notes.txt} {1bad} {a-b}")
            .map(|m| m.as_str())
            .collect();
        assert_eq!(found, vec!["{name}", "{app:cfg?}", "{artifact.notes.txt}"]);
    }
}