1#![allow(clippy::collapsible_if, clippy::collapsible_else_if)]
2
3use aiproof_core::document::{Document, Kind, PromptText, Role};
4use aiproof_core::span::Span;
5use std::path::Path;
6use tree_sitter::{Node, Parser};
7
8pub fn parse(path: &Path, source: &str) -> anyhow::Result<Vec<Document>> {
9 let mut parser = Parser::new();
10 parser.set_language(&tree_sitter_python::language())?;
11
12 let tree = match parser.parse(source, None) {
13 Some(t) => t,
14 None => return Ok(Vec::new()),
15 };
16
17 let mut docs = Vec::new();
18 walk(tree.root_node(), source, path, &mut docs);
19 Ok(docs)
20}
21
22fn walk<'a>(node: Node<'a>, source: &str, path: &Path, docs: &mut Vec<Document>) {
23 if node.kind() == "call" {
24 handle_call(node, source, path, docs);
25 }
26
27 for i in 0..node.named_child_count() {
28 if let Some(child) = node.named_child(i) {
29 walk(child, source, path, docs);
30 }
31 }
32}
33
34fn handle_call(call: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
35 let Some(func) = call.child_by_field_name("function") else {
36 return;
37 };
38
39 let dotted = dotted_tail(func, source);
40 let args = call.child_by_field_name("arguments");
41
42 match dotted.as_str() {
43 s if s.ends_with("messages.create") => {
44 extract_system_kwarg(call, args, source, path, docs, "python-anthropic");
45 extract_messages_kwarg(call, args, source, path, docs);
46 if let Some(temp) = extract_temperature_kwarg(args, source) {
47 attach_temperature_to_last_n_docs(docs, temp, 2);
48 }
49 }
50 s if s.ends_with("completions.create") => {
51 extract_messages_kwarg(call, args, source, path, docs);
52 if let Some(temp) = extract_temperature_kwarg(args, source) {
53 attach_temperature_to_last_n_docs(docs, temp, 999);
54 }
55 }
56 "PromptTemplate" => {
57 extract_template_kwarg(call, args, source, path, docs);
58 }
59 "PromptTemplate.from_template" => {
60 extract_first_positional_string(call, args, source, path, docs, Role::Unknown);
61 }
62 "ChatPromptTemplate.from_messages" => {
63 extract_from_messages_list(call, args, source, path, docs);
64 }
65 "Agent" => {
66 extract_system_kwarg(call, args, source, path, docs, "python-agent");
67 }
68 _ => {}
69 }
70}
71
72fn dotted_tail(node: Node, source: &str) -> String {
74 let mut parts = Vec::new();
75 let mut current = node;
76
77 loop {
78 if current.kind() == "attribute" {
79 if let Some(attr) = current.child_by_field_name("attribute") {
80 if let Ok(name) = node_text(&attr, source) {
81 parts.push(name);
82 }
83 }
84 if let Some(obj) = current.child_by_field_name("object") {
85 current = obj;
86 continue;
87 }
88 } else if current.kind() == "identifier" {
89 if let Ok(name) = node_text(¤t, source) {
90 parts.push(name);
91 }
92 }
93 break;
94 }
95
96 parts.reverse();
97 parts.join(".")
98}
99
100fn node_text(node: &Node, source: &str) -> Result<String, ()> {
101 let start = node.start_byte();
102 let end = node.end_byte();
103 if start < end && end <= source.len() {
104 Ok(source[start..end].to_string())
105 } else {
106 Err(())
107 }
108}
109
110fn extract_system_kwarg(
111 call: Node,
112 args: Option<Node>,
113 source: &str,
114 path: &Path,
115 docs: &mut Vec<Document>,
116 _origin: &str,
117) {
118 let Some(args) = args else { return };
119
120 for i in 0..args.named_child_count() {
121 if let Some(child) = args.named_child(i) {
122 if child.kind() == "keyword_argument" {
123 if let Some(name) = child.child_by_field_name("name") {
124 if let Ok(name_text) = node_text(&name, source) {
125 if name_text == "system" {
126 if let Some(value) = child.child_by_field_name("value") {
127 if let Some((text, span)) = resolve_string_literal(value, source) {
128 docs.push(Document {
129 path: path.to_path_buf(),
130 role: Role::System,
131 source: source.to_string(),
132 prompt: PromptText {
133 text,
134 origin_span: Some(span),
135 },
136 kind: Kind::ExtractedPython {
137 call_site: Span::from_byte_range(
138 source,
139 call.start_byte()..call.end_byte(),
140 ),
141 temperature: None,
142 },
143 });
144 }
145 }
146 }
147 }
148 }
149 }
150 }
151 }
152}
153
154fn extract_messages_kwarg(
155 _call: Node,
156 args: Option<Node>,
157 source: &str,
158 path: &Path,
159 docs: &mut Vec<Document>,
160) {
161 let Some(args) = args else { return };
162
163 for i in 0..args.named_child_count() {
164 if let Some(child) = args.named_child(i) {
165 if child.kind() == "keyword_argument" {
166 if let Some(name) = child.child_by_field_name("name") {
167 if let Ok(name_text) = node_text(&name, source) {
168 if name_text == "messages" {
169 if let Some(value) = child.child_by_field_name("value") {
170 extract_messages_from_list(value, source, path, docs);
171 }
172 }
173 }
174 }
175 }
176 }
177 }
178}
179
180fn extract_temperature_kwarg(args: Option<Node>, source: &str) -> Option<f32> {
181 let args = args?;
182
183 for i in 0..args.named_child_count() {
184 if let Some(child) = args.named_child(i) {
185 if child.kind() == "keyword_argument" {
186 if let Some(name) = child.child_by_field_name("name") {
187 if let Ok(name_text) = node_text(&name, source) {
188 if name_text == "temperature" {
189 if let Some(value) = child.child_by_field_name("value") {
190 if let Ok(text) = node_text(&value, source) {
191 if let Ok(temp) = text.parse::<f32>() {
192 return Some(temp);
193 }
194 }
195 }
196 }
197 }
198 }
199 }
200 }
201 }
202 None
203}
204
205fn attach_temperature_to_last_n_docs(docs: &mut [Document], temp: f32, n: usize) {
206 let start = if docs.len() > n { docs.len() - n } else { 0 };
207 for doc in &mut docs[start..] {
208 if let Kind::ExtractedPython { temperature, .. } = &mut doc.kind {
209 *temperature = Some(temp);
210 }
211 }
212}
213
214fn extract_messages_from_list(list: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
215 if list.kind() != "list" {
216 return;
217 }
218
219 for i in 0..list.named_child_count() {
220 if let Some(child) = list.named_child(i) {
221 if child.kind() == "dictionary" {
222 extract_message_dict(child, source, path, docs);
223 }
224 }
225 }
226}
227
228fn extract_message_dict(dict: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
229 let mut role = None;
230 let mut content = None;
231
232 for i in 0..dict.named_child_count() {
233 if let Some(child) = dict.named_child(i) {
234 if child.kind() == "pair" {
235 if let Some(key) = child.child_by_field_name("key") {
236 if let Some(val) = child.child_by_field_name("value") {
237 if let Ok(key_text) = node_text(&key, source) {
238 match key_text.trim_matches('\"').trim_matches('\'') {
239 "role" => {
240 if let Ok(val_text) = node_text(&val, source) {
241 role = Some(
242 val_text
243 .trim_matches('\"')
244 .trim_matches('\'')
245 .to_string(),
246 );
247 }
248 }
249 "content" => {
250 content = resolve_string_literal(val, source);
251 }
252 _ => {}
253 }
254 }
255 }
256 }
257 }
258 }
259 }
260
261 if let (Some(role_str), Some((text, origin_span))) = (role, content) {
262 let role_enum = match role_str.as_str() {
263 "system" => Role::System,
264 "user" => Role::User,
265 "assistant" => Role::Assistant,
266 "tool" => Role::Tool,
267 _ => Role::Unknown,
268 };
269
270 docs.push(Document {
271 path: path.to_path_buf(),
272 role: role_enum,
273 source: source.to_string(),
274 prompt: PromptText {
275 text,
276 origin_span: Some(origin_span),
277 },
278 kind: Kind::ExtractedPython {
279 call_site: Span::from_byte_range(source, dict.start_byte()..dict.end_byte()),
280 temperature: None,
281 },
282 });
283 }
284}
285
286fn extract_template_kwarg(
287 _call: Node,
288 args: Option<Node>,
289 source: &str,
290 path: &Path,
291 docs: &mut Vec<Document>,
292) {
293 let Some(args) = args else { return };
294
295 for i in 0..args.named_child_count() {
296 if let Some(child) = args.named_child(i) {
297 if child.kind() == "keyword_argument" {
298 if let Some(name) = child.child_by_field_name("name") {
299 if let Ok(name_text) = node_text(&name, source) {
300 if name_text == "template" {
301 if let Some(value) = child.child_by_field_name("value") {
302 if let Some((text, span)) = resolve_string_literal(value, source) {
303 docs.push(Document {
304 path: path.to_path_buf(),
305 role: Role::Unknown,
306 source: source.to_string(),
307 prompt: PromptText {
308 text,
309 origin_span: Some(span),
310 },
311 kind: Kind::ExtractedPython {
312 call_site: Span::from_byte_range(
313 source,
314 child.start_byte()..child.end_byte(),
315 ),
316 temperature: None,
317 },
318 });
319 }
320 }
321 }
322 }
323 }
324 }
325 }
326 }
327}
328
329fn extract_first_positional_string(
330 call: Node,
331 args: Option<Node>,
332 source: &str,
333 path: &Path,
334 docs: &mut Vec<Document>,
335 role: Role,
336) {
337 let Some(args) = args else { return };
338
339 for i in 0..args.named_child_count() {
340 if let Some(child) = args.named_child(i) {
341 let is_string_arg = child.kind() == "string" || child.kind() == "argument";
342 if is_string_arg {
343 if let Some((text, span)) = resolve_string_literal(child, source) {
344 docs.push(Document {
345 path: path.to_path_buf(),
346 role,
347 source: source.to_string(),
348 prompt: PromptText {
349 text,
350 origin_span: Some(span),
351 },
352 kind: Kind::ExtractedPython {
353 call_site: Span::from_byte_range(
354 source,
355 call.start_byte()..call.end_byte(),
356 ),
357 temperature: None,
358 },
359 });
360 return; }
362 }
363 }
364 }
365}
366
367fn extract_from_messages_list(
368 _call: Node,
369 args: Option<Node>,
370 source: &str,
371 path: &Path,
372 docs: &mut Vec<Document>,
373) {
374 let Some(args) = args else { return };
375
376 for i in 0..args.named_child_count() {
377 if let Some(child) = args.named_child(i) {
378 if child.kind() == "list" {
379 for j in 0..child.named_child_count() {
380 if let Some(item) = child.named_child(j) {
381 extract_from_messages_tuple(item, source, path, docs);
382 }
383 }
384 }
385 }
386 }
387}
388
389fn extract_from_messages_tuple(tuple: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
390 if tuple.kind() != "tuple" {
391 return;
392 }
393
394 let mut role = None;
395 let mut content = None;
396
397 for i in 0..tuple.named_child_count() {
398 if let Some(child) = tuple.named_child(i) {
399 match i {
400 0 => {
401 if let Ok(text) = node_text(&child, source) {
402 role = Some(text.trim_matches('\"').trim_matches('\'').to_string());
403 }
404 }
405 1 => {
406 content = resolve_string_literal(child, source);
407 }
408 _ => {}
409 }
410 }
411 }
412
413 if let (Some(role_str), Some((text, origin_span))) = (role, content) {
414 let role_enum = match role_str.as_str() {
415 "system" => Role::System,
416 "user" => Role::User,
417 "assistant" => Role::Assistant,
418 "tool" => Role::Tool,
419 _ => Role::Unknown,
420 };
421
422 docs.push(Document {
423 path: path.to_path_buf(),
424 role: role_enum,
425 source: source.to_string(),
426 prompt: PromptText {
427 text,
428 origin_span: Some(origin_span),
429 },
430 kind: Kind::ExtractedPython {
431 call_site: Span::from_byte_range(source, tuple.start_byte()..tuple.end_byte()),
432 temperature: None,
433 },
434 });
435 }
436}
437
438fn resolve_string_literal(node: Node, source: &str) -> Option<(String, Span)> {
442 if node.kind() != "string" {
443 return None;
444 }
445
446 let start = node.start_byte();
447 let end = node.end_byte();
448 let span = Span::from_byte_range(source, start..end);
449
450 let raw_text = &source[start..end];
451
452 if raw_text.starts_with("f\"")
453 || raw_text.starts_with("f'")
454 || raw_text.starts_with("F\"")
455 || raw_text.starts_with("F'")
456 || raw_text.starts_with("rf\"")
457 || raw_text.starts_with("fr\"")
458 || raw_text.starts_with("rf'")
459 || raw_text.starts_with("fr'")
460 {
461 let text = reconstruct_fstring(node, source);
462 return Some((text, span));
463 }
464
465 if raw_text.starts_with("r\"")
466 || raw_text.starts_with("r'")
467 || raw_text.starts_with("R\"")
468 || raw_text.starts_with("R'")
469 {
470 let quote_char = if raw_text.contains("\"\"\"") || raw_text.contains("'''") {
471 &raw_text[2..5]
472 } else {
473 &raw_text[2..3]
474 };
475 let inner = extract_string_inner(raw_text, quote_char);
476 return Some((inner, span));
477 }
478
479 let quote_char = if raw_text.starts_with("\"\"\"") || raw_text.starts_with("'''") {
480 &raw_text[..3]
481 } else {
482 &raw_text[..1]
483 };
484
485 let inner = extract_string_inner(raw_text, quote_char);
486 let unescaped = unescape_string(&inner);
487 Some((unescaped, span))
488}
489
/// Strips an optional Python string prefix and one `quote` delimiter from each
/// end of the literal `raw`, without any escape processing.
///
/// The prefix is whatever ASCII letters precede the first quote character
/// (`r`, `u`, `b`, `f` in any case and combination, e.g. `Rb`, `fR`). If the
/// text before the first quote is not all letters, or there is no quote, `raw`
/// is used as-is. A missing opening or closing `quote` is tolerated: that side
/// is simply left untouched.
fn extract_string_inner(raw: &str, quote: &str) -> String {
    let body = match raw.find(|c: char| c == '"' || c == '\'') {
        Some(open) if raw[..open].bytes().all(|b| b.is_ascii_alphabetic()) => &raw[open..],
        _ => raw,
    };
    let body = body.strip_prefix(quote).unwrap_or(body);
    body.strip_suffix(quote).unwrap_or(body).to_string()
}
506
/// Decodes the backslash escape sequences of a non-raw Python string body.
///
/// Supported: `\\`, `\'`, `\"`, `\a`, `\b`, `\f`, `\n`, `\r`, `\t`, `\v`,
/// octal `\o`..`\ooo`, `\xhh`, `\uXXXX`, `\UXXXXXXXX`, and backslash-newline
/// line continuations (removed entirely, as Python does). Anything else is kept
/// verbatim, backslash included: unknown escapes, `\N{...}`, malformed or
/// out-of-range numeric escapes, and a trailing lone backslash. Bytes-literal
/// escape rules are not distinguished from `str` rules.
fn unescape_string(s: &str) -> String {
    // Parses exactly `len` hex digits at the start of `text` into a `char`.
    fn hex_char(text: &str, len: usize) -> Option<char> {
        let digits = text.get(..len)?;
        if !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        u32::from_str_radix(digits, 16).ok().and_then(char::from_u32)
    }

    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(pos) = rest.find('\\') {
        out.push_str(&rest[..pos]);
        let after = &rest[pos + 1..];
        let mut chars = after.chars();
        let Some(esc) = chars.next() else {
            // A trailing lone backslash escapes nothing; keep it.
            out.push('\\');
            return out;
        };
        let tail = chars.as_str();

        rest = match esc {
            '\\' | '\'' | '"' => {
                out.push(esc);
                tail
            }
            'a' => {
                out.push('\u{07}');
                tail
            }
            'b' => {
                out.push('\u{08}');
                tail
            }
            'f' => {
                out.push('\u{0C}');
                tail
            }
            'n' => {
                out.push('\n');
                tail
            }
            'r' => {
                out.push('\r');
                tail
            }
            't' => {
                out.push('\t');
                tail
            }
            'v' => {
                out.push('\u{0B}');
                tail
            }
            // Backslash-newline is a line continuation and produces nothing.
            '\n' => tail,
            '\r' => tail.strip_prefix('\n').unwrap_or(tail),
            'x' | 'u' | 'U' => {
                let len = match esc {
                    'x' => 2,
                    'u' => 4,
                    _ => 8,
                };
                match hex_char(tail, len) {
                    Some(ch) => {
                        out.push(ch);
                        &tail[len..]
                    }
                    // Malformed or not a valid scalar value: keep it as written.
                    None => {
                        out.push('\\');
                        after
                    }
                }
            }
            '0'..='7' => {
                // The escape char plus up to two more octal digits.
                let len = 1 + tail
                    .bytes()
                    .take(2)
                    .take_while(|b| (b'0'..=b'7').contains(b))
                    .count();
                let code = after
                    .bytes()
                    .take(len)
                    .fold(0u32, |acc, b| acc * 8 + u32::from(b - b'0'));
                // At most 0o777 (511), which is always a valid `char`.
                out.push(char::from_u32(code).unwrap_or(char::REPLACEMENT_CHARACTER));
                &after[len..]
            }
            _ => {
                // Unknown escapes (including `\N{...}`) are kept as written.
                out.push('\\');
                after
            }
        };
    }

    out.push_str(rest);
    out
}
547
548fn reconstruct_fstring(node: Node, source: &str) -> String {
549 let mut result = String::new();
550 let mut placeholder_index = 0;
551
552 for i in 0..node.named_child_count() {
553 if let Some(child) = node.named_child(i) {
554 match child.kind() {
555 "string_content" => {
556 if let Ok(text) = node_text(&child, source) {
557 let unescaped = unescape_string(&text);
558 result.push_str(&unescaped);
559 }
560 }
561 "interpolation" => {
562 result.push_str(&format!("{{{}}}", placeholder_index));
563 placeholder_index += 1;
564 }
565 _ => {}
566 }
567 }
568 }
569
570 if result.is_empty() {
571 let start = node.start_byte();
572 let end = node.end_byte();
573 if start < end && end <= source.len() {
574 let raw = &source[start..end];
575 let quote = if raw.contains("\"\"\"") || raw.contains("'''") {
576 &raw[..3]
577 } else {
578 &raw[2..3]
579 };
580 extract_string_inner(raw, quote)
581 } else {
582 String::new()
583 }
584 } else {
585 result
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor over `src` as though it were the file `t.py`.
    fn extract_all(src: &str) -> Vec<Document> {
        parse(Path::new("t.py"), src).unwrap()
    }

    /// First extracted document; panics if nothing was extracted.
    fn first(src: &str) -> Document {
        extract_all(src).remove(0)
    }

    /// Prompt text of the first document with `role`; panics if there is none.
    fn text_with_role(docs: &[Document], role: Role) -> &str {
        let doc = docs.iter().find(|d| d.role == role).unwrap();
        &doc.prompt.text
    }

    #[test]
    fn anthropic_system_extracted() {
        let doc = first(
            r#"
client.messages.create(
    model="claude-4.7-opus",
    system="You are a helpful assistant.",
    messages=[{"role": "user", "content": "Hello"}],
)
"#,
        );
        assert_eq!(doc.prompt.text, "You are a helpful assistant.");
        assert_eq!(doc.role, Role::System);
    }

    #[test]
    fn openai_messages_extracted() {
        let docs = extract_all(
            r#"
openai.chat.completions.create(
    messages=[
        {"role": "system", "content": "Act as a tutor."},
        {"role": "user", "content": "Teach me fractions."},
    ],
)
"#,
        );
        assert_eq!(docs.len(), 2);
        assert_eq!(text_with_role(&docs, Role::System), "Act as a tutor.");
        assert_eq!(text_with_role(&docs, Role::User), "Teach me fractions.");
    }

    #[test]
    fn prompttemplate_from_template() {
        let docs = extract_all(r#"PromptTemplate.from_template("Answer this: {q}")"#);
        let Some(doc) = docs.first() else {
            panic!("Expected at least one document, got {}", docs.len());
        };
        assert_eq!(doc.prompt.text, "Answer this: {q}");
    }

    #[test]
    fn chatprompttemplate_from_messages() {
        let docs = extract_all(
            r#"
ChatPromptTemplate.from_messages([
    ("system", "You are helpful."),
    ("user", "Q: {q}"),
])
"#,
        );
        assert_eq!(docs.len(), 2);
    }

    #[test]
    fn fstring_becomes_positional_placeholder() {
        let doc = first(
            r#"
client.messages.create(
    system=f"You are {name}. Tone: {tone}.",
    messages=[],
)
"#,
        );
        assert_eq!(doc.prompt.text, "You are {0}. Tone: {1}.");
    }

    #[test]
    fn dynamic_expression_skipped() {
        let docs = extract_all(
            r#"
client.messages.create(
    system=SOMETHING_DYNAMIC,
    messages=[],
)
"#,
        );
        assert!(docs.is_empty());
    }
}