codex_convert_proxy/convert/request/
context.rs1use std::collections::HashSet;
22use std::str::FromStr;
23
24use crate::types::response_api::Tool;
25
/// Strategy for combining predefined tools with tools discovered via tool search.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ToolPriority {
    /// Use only the predefined tools; searched tools are ignored.
    PreferDefined,
    /// Use only the searched tools; predefined tools are ignored.
    PreferSearched,
    /// Combine both lists, de-duplicating by tool name. This is the default.
    #[default]
    Merge,
}
37
38impl FromStr for ToolPriority {
39 type Err = std::convert::Infallible;
40
41 fn from_str(s: &str) -> Result<Self, Self::Err> {
42 Ok(match s.to_lowercase().as_str() {
43 "prefer_defined" | "prefer-defined" => ToolPriority::PreferDefined,
44 "prefer_searched" | "prefer-searched" => ToolPriority::PreferSearched,
45 "merge" | "combined" => ToolPriority::Merge,
46 _ => {
47 tracing::warn!(
48 "[TOOL_SEARCH] unknown tool_priority '{}', defaulting to 'merge'",
49 s
50 );
51 ToolPriority::Merge
52 }
53 })
54 }
55}
56
57#[derive(Debug, Clone)]
64pub struct ToolSearchContext {
65 pending_calls: HashSet<String>,
67 searched_tools: Vec<Tool>,
69 predefined_tools: Vec<Tool>,
71 resolved_tools: Vec<Tool>,
73 priority: ToolPriority,
75 finalized: bool,
77}
78
79impl ToolSearchContext {
80 pub fn new(priority: ToolPriority) -> Self {
82 Self {
83 pending_calls: HashSet::new(),
84 searched_tools: Vec::new(),
85 predefined_tools: Vec::new(),
86 resolved_tools: Vec::new(),
87 priority,
88 finalized: false,
89 }
90 }
91
92 pub fn register_search_call(&mut self, call_id: &str) {
94 self.pending_calls.insert(call_id.to_string());
95 tracing::debug!(
96 "[TOOL_SEARCH] registered tool_search_call: id={}",
97 call_id
98 );
99 }
100
101 pub fn is_registered_search(&self, call_id: &str) -> bool {
103 self.pending_calls.contains(call_id)
104 }
105
106 pub fn set_predefined_tools(&mut self, tools: Vec<Tool>) {
110 self.predefined_tools = tools;
111 tracing::debug!(
112 "[TOOL_SEARCH] set predefined tools: count={}",
113 self.predefined_tools.len()
114 );
115 }
116
117 pub fn add_searched_tools(&mut self, tools: Vec<Tool>, call_id: &str) {
121 if !self.pending_calls.contains(call_id) {
122 tracing::warn!(
123 "[TOOL_SEARCH] tool_search_output has unrecognized call_id '{}', \
124 it may not have a corresponding tool_search_call",
125 call_id
126 );
127 }
128
129 let count = tools.len();
130 self.searched_tools.extend(tools);
131 tracing::debug!(
132 "[TOOL_SEARCH] added {} tools from tool_search_output: call_id={}, total_searched={}",
133 count,
134 call_id,
135 self.searched_tools.len()
136 );
137 }
138
139 pub fn complete_search(&mut self, call_id: &str) {
144 self.pending_calls.remove(call_id);
145 tracing::debug!(
146 "[TOOL_SEARCH] completed tool_search_call: call_id={}",
147 call_id
148 );
149 }
150
151 #[must_use]
159 pub fn finalize(mut self) -> Vec<Tool> {
160 if self.finalized {
161 return std::mem::take(&mut self.resolved_tools);
162 }
163
164 tracing::debug!(
165 "[TOOL_SEARCH] finalizing: predefined={}, searched={}, priority={:?}",
166 self.predefined_tools.len(),
167 self.searched_tools.len(),
168 self.priority
169 );
170
171 let result = match self.priority {
172 ToolPriority::PreferDefined => {
173 if !self.searched_tools.is_empty() {
174 tracing::info!(
175 "[TOOL_SEARCH] PreferDefined: ignoring {} searched tools",
176 self.searched_tools.len()
177 );
178 }
179 std::mem::take(&mut self.predefined_tools)
180 }
181 ToolPriority::PreferSearched => {
182 if !self.predefined_tools.is_empty() {
183 tracing::info!(
184 "[TOOL_SEARCH] PreferSearched: ignoring {} predefined tools",
185 self.predefined_tools.len()
186 );
187 }
188 std::mem::take(&mut self.searched_tools)
189 }
190 ToolPriority::Merge => {
191 merge_tools_map(&self.predefined_tools, &self.searched_tools)
192 }
193 };
194
195 self.finalized = true;
196 self.resolved_tools = result.clone();
197 tracing::debug!(
198 "[TOOL_SEARCH] resolved tools: count={}",
199 result.len()
200 );
201
202 result
203 }
204
205 pub fn resolved_tools(&self) -> &[Tool] {
207 &self.resolved_tools
208 }
209
210 pub fn priority(&self) -> ToolPriority {
212 self.priority
213 }
214
215 pub fn has_pending_searches(&self) -> bool {
217 !self.pending_calls.is_empty()
218 }
219
220 pub fn predefined_count(&self) -> usize {
222 self.predefined_tools.len()
223 }
224
225 pub fn searched_count(&self) -> usize {
227 self.searched_tools.len()
228 }
229}
230
231pub(crate) fn merge_tools_map(first: &[Tool], second: &[Tool]) -> Vec<Tool> {
241 use std::collections::HashMap;
242
243 let mut name_to_tool: HashMap<String, &Tool> = HashMap::new();
244
245 for tool in first {
247 if let Some(name) = &tool.name {
248 name_to_tool.insert(name.clone(), tool);
249 }
250 }
251
252 for tool in second {
254 if let Some(name) = &tool.name {
255 name_to_tool.insert(name.clone(), tool);
256 }
257 }
258
259 let mut result: Vec<Tool> = Vec::new();
262 let mut seen: HashSet<String> = HashSet::new();
263
264 for tool in first {
266 if let Some(name) = &tool.name
267 && !seen.contains(name)
268 {
269 result.push(tool.clone());
270 seen.insert(name.clone());
271 }
272 }
273
274 for tool in second {
276 if let Some(name) = &tool.name {
277 let first_had = first.iter().any(|t| t.name.as_ref() == Some(name));
279 if !first_had || !seen.contains(name) {
280 if !seen.contains(name) {
283 result.push(tool.clone());
284 seen.insert(name.clone());
285 }
286 }
287 }
288 }
289
290 result
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::response_api::{Tool, ToolType};

    /// Builds a minimal function tool that carries only a name.
    fn make_tool(name: &str) -> Tool {
        Tool {
            tool_type: ToolType::Function,
            name: Some(name.to_string()),
            description: None,
            parameters: None,
            strict: None,
            extra: std::collections::HashMap::new(),
        }
    }

    /// Returns the names of `tools` in order, skipping unnamed tools.
    fn names(tools: &[Tool]) -> Vec<&str> {
        tools.iter().filter_map(|t| t.name.as_deref()).collect()
    }

    #[test]
    fn test_prefer_defined_keeps_predefined() {
        let mut ctx = ToolSearchContext::new(ToolPriority::PreferDefined);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_c")], "call_1");

        let resolved = ctx.finalize();
        assert_eq!(names(&resolved), ["tool_a", "tool_b"]);
    }

    #[test]
    fn test_prefer_searched_keeps_searched() {
        let mut ctx = ToolSearchContext::new(ToolPriority::PreferSearched);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_c"), make_tool("tool_d")], "call_1");

        let resolved = ctx.finalize();
        assert_eq!(names(&resolved), ["tool_c", "tool_d"]);
    }

    #[test]
    fn test_merge_combines_all_unique() {
        let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_c"), make_tool("tool_d")], "call_1");
        assert_eq!(ctx.predefined_count(), 2);
        assert_eq!(ctx.searched_count(), 2);

        let resolved = ctx.finalize();
        assert_eq!(names(&resolved), ["tool_a", "tool_b", "tool_c", "tool_d"]);
    }

    #[test]
    fn test_merge_searched_overrides_on_conflict() {
        let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_b"), make_tool("tool_c")], "call_1");

        let resolved = ctx.finalize();
        // The duplicate `tool_b` collapses to a single entry at its predefined position.
        assert_eq!(names(&resolved), ["tool_a", "tool_b", "tool_c"]);
    }

    #[test]
    fn test_register_search_call() {
        let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
        ctx.register_search_call("call_123");

        assert!(ctx.is_registered_search("call_123"));
        assert!(!ctx.is_registered_search("call_456"));
        assert!(ctx.has_pending_searches());

        ctx.complete_search("call_123");
        assert!(!ctx.is_registered_search("call_123"));
        assert!(!ctx.has_pending_searches());
    }

    #[test]
    fn test_priority_from_str() {
        assert_eq!("prefer_defined".parse::<ToolPriority>(), Ok(ToolPriority::PreferDefined));
        assert_eq!("prefer-searched".parse::<ToolPriority>(), Ok(ToolPriority::PreferSearched));
        assert_eq!("merge".parse::<ToolPriority>(), Ok(ToolPriority::Merge));
        assert_eq!("combined".parse::<ToolPriority>(), Ok(ToolPriority::Merge));
        assert_eq!("Prefer-Defined".parse::<ToolPriority>(), Ok(ToolPriority::PreferDefined));
        // Unknown values fall back to `Merge` instead of failing.
        assert_eq!("unknown".parse::<ToolPriority>(), Ok(ToolPriority::Merge));
    }

    /// `finalize` consumes the context, so it cannot be called twice. Instead,
    /// this checks that two identically configured contexts resolve identically.
    #[test]
    fn test_finalize_is_deterministic() {
        let build = || {
            let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
            ctx.set_predefined_tools(vec![make_tool("tool_a")]);
            ctx.add_searched_tools(vec![make_tool("tool_b")], "call_1");
            ctx
        };

        let first = build().finalize();
        let second = build().finalize();

        assert_eq!(names(&first), ["tool_a", "tool_b"]);
        assert_eq!(names(&first), names(&second));
    }
}