Skip to main content

aiproof_parse/
python_extract.rs

1#![allow(clippy::collapsible_if, clippy::collapsible_else_if)]
2
3use aiproof_core::document::{Document, Kind, PromptText, Role};
4use aiproof_core::span::Span;
5use std::path::Path;
6use tree_sitter::{Node, Parser};
7
8pub fn parse(path: &Path, source: &str) -> anyhow::Result<Vec<Document>> {
9    let mut parser = Parser::new();
10    parser.set_language(&tree_sitter_python::language())?;
11
12    let tree = match parser.parse(source, None) {
13        Some(t) => t,
14        None => return Ok(Vec::new()),
15    };
16
17    let mut docs = Vec::new();
18    walk(tree.root_node(), source, path, &mut docs);
19    Ok(docs)
20}
21
22fn walk<'a>(node: Node<'a>, source: &str, path: &Path, docs: &mut Vec<Document>) {
23    if node.kind() == "call" {
24        handle_call(node, source, path, docs);
25    }
26
27    for i in 0..node.named_child_count() {
28        if let Some(child) = node.named_child(i) {
29            walk(child, source, path, docs);
30        }
31    }
32}
33
34fn handle_call(call: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
35    let Some(func) = call.child_by_field_name("function") else {
36        return;
37    };
38
39    let dotted = dotted_tail(func, source);
40    let args = call.child_by_field_name("arguments");
41
42    match dotted.as_str() {
43        s if s.ends_with("messages.create") => {
44            extract_system_kwarg(call, args, source, path, docs, "python-anthropic");
45            extract_messages_kwarg(call, args, source, path, docs);
46            if let Some(temp) = extract_temperature_kwarg(args, source) {
47                attach_temperature_to_last_n_docs(docs, temp, 2);
48            }
49        }
50        s if s.ends_with("completions.create") => {
51            extract_messages_kwarg(call, args, source, path, docs);
52            if let Some(temp) = extract_temperature_kwarg(args, source) {
53                attach_temperature_to_last_n_docs(docs, temp, 999);
54            }
55        }
56        "PromptTemplate" => {
57            extract_template_kwarg(call, args, source, path, docs);
58        }
59        "PromptTemplate.from_template" => {
60            extract_first_positional_string(call, args, source, path, docs, Role::Unknown);
61        }
62        "ChatPromptTemplate.from_messages" => {
63            extract_from_messages_list(call, args, source, path, docs);
64        }
65        "Agent" => {
66            extract_system_kwarg(call, args, source, path, docs, "python-agent");
67        }
68        _ => {}
69    }
70}
71
72/// Returns the dotted tail of an attribute expression, e.g. "messages.create".
73fn dotted_tail(node: Node, source: &str) -> String {
74    let mut parts = Vec::new();
75    let mut current = node;
76
77    loop {
78        if current.kind() == "attribute" {
79            if let Some(attr) = current.child_by_field_name("attribute") {
80                if let Ok(name) = node_text(&attr, source) {
81                    parts.push(name);
82                }
83            }
84            if let Some(obj) = current.child_by_field_name("object") {
85                current = obj;
86                continue;
87            }
88        } else if current.kind() == "identifier" {
89            if let Ok(name) = node_text(&current, source) {
90                parts.push(name);
91            }
92        }
93        break;
94    }
95
96    parts.reverse();
97    parts.join(".")
98}
99
100fn node_text(node: &Node, source: &str) -> Result<String, ()> {
101    let start = node.start_byte();
102    let end = node.end_byte();
103    if start < end && end <= source.len() {
104        Ok(source[start..end].to_string())
105    } else {
106        Err(())
107    }
108}
109
110fn extract_system_kwarg(
111    call: Node,
112    args: Option<Node>,
113    source: &str,
114    path: &Path,
115    docs: &mut Vec<Document>,
116    _origin: &str,
117) {
118    let Some(args) = args else { return };
119
120    for i in 0..args.named_child_count() {
121        if let Some(child) = args.named_child(i) {
122            if child.kind() == "keyword_argument" {
123                if let Some(name) = child.child_by_field_name("name") {
124                    if let Ok(name_text) = node_text(&name, source) {
125                        if name_text == "system" {
126                            if let Some(value) = child.child_by_field_name("value") {
127                                if let Some((text, span)) = resolve_string_literal(value, source) {
128                                    docs.push(Document {
129                                        path: path.to_path_buf(),
130                                        role: Role::System,
131                                        source: source.to_string(),
132                                        prompt: PromptText {
133                                            text,
134                                            origin_span: Some(span),
135                                        },
136                                        kind: Kind::ExtractedPython {
137                                            call_site: Span::from_byte_range(
138                                                source,
139                                                call.start_byte()..call.end_byte(),
140                                            ),
141                                            temperature: None,
142                                        },
143                                    });
144                                }
145                            }
146                        }
147                    }
148                }
149            }
150        }
151    }
152}
153
154fn extract_messages_kwarg(
155    _call: Node,
156    args: Option<Node>,
157    source: &str,
158    path: &Path,
159    docs: &mut Vec<Document>,
160) {
161    let Some(args) = args else { return };
162
163    for i in 0..args.named_child_count() {
164        if let Some(child) = args.named_child(i) {
165            if child.kind() == "keyword_argument" {
166                if let Some(name) = child.child_by_field_name("name") {
167                    if let Ok(name_text) = node_text(&name, source) {
168                        if name_text == "messages" {
169                            if let Some(value) = child.child_by_field_name("value") {
170                                extract_messages_from_list(value, source, path, docs);
171                            }
172                        }
173                    }
174                }
175            }
176        }
177    }
178}
179
180fn extract_temperature_kwarg(args: Option<Node>, source: &str) -> Option<f32> {
181    let args = args?;
182
183    for i in 0..args.named_child_count() {
184        if let Some(child) = args.named_child(i) {
185            if child.kind() == "keyword_argument" {
186                if let Some(name) = child.child_by_field_name("name") {
187                    if let Ok(name_text) = node_text(&name, source) {
188                        if name_text == "temperature" {
189                            if let Some(value) = child.child_by_field_name("value") {
190                                if let Ok(text) = node_text(&value, source) {
191                                    if let Ok(temp) = text.parse::<f32>() {
192                                        return Some(temp);
193                                    }
194                                }
195                            }
196                        }
197                    }
198                }
199            }
200        }
201    }
202    None
203}
204
205fn attach_temperature_to_last_n_docs(docs: &mut [Document], temp: f32, n: usize) {
206    let start = if docs.len() > n { docs.len() - n } else { 0 };
207    for doc in &mut docs[start..] {
208        if let Kind::ExtractedPython { temperature, .. } = &mut doc.kind {
209            *temperature = Some(temp);
210        }
211    }
212}
213
214fn extract_messages_from_list(list: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
215    if list.kind() != "list" {
216        return;
217    }
218
219    for i in 0..list.named_child_count() {
220        if let Some(child) = list.named_child(i) {
221            if child.kind() == "dictionary" {
222                extract_message_dict(child, source, path, docs);
223            }
224        }
225    }
226}
227
228fn extract_message_dict(dict: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
229    let mut role = None;
230    let mut content = None;
231
232    for i in 0..dict.named_child_count() {
233        if let Some(child) = dict.named_child(i) {
234            if child.kind() == "pair" {
235                if let Some(key) = child.child_by_field_name("key") {
236                    if let Some(val) = child.child_by_field_name("value") {
237                        if let Ok(key_text) = node_text(&key, source) {
238                            match key_text.trim_matches('\"').trim_matches('\'') {
239                                "role" => {
240                                    if let Ok(val_text) = node_text(&val, source) {
241                                        role = Some(
242                                            val_text
243                                                .trim_matches('\"')
244                                                .trim_matches('\'')
245                                                .to_string(),
246                                        );
247                                    }
248                                }
249                                "content" => {
250                                    content = resolve_string_literal(val, source);
251                                }
252                                _ => {}
253                            }
254                        }
255                    }
256                }
257            }
258        }
259    }
260
261    if let (Some(role_str), Some((text, origin_span))) = (role, content) {
262        let role_enum = match role_str.as_str() {
263            "system" => Role::System,
264            "user" => Role::User,
265            "assistant" => Role::Assistant,
266            "tool" => Role::Tool,
267            _ => Role::Unknown,
268        };
269
270        docs.push(Document {
271            path: path.to_path_buf(),
272            role: role_enum,
273            source: source.to_string(),
274            prompt: PromptText {
275                text,
276                origin_span: Some(origin_span),
277            },
278            kind: Kind::ExtractedPython {
279                call_site: Span::from_byte_range(source, dict.start_byte()..dict.end_byte()),
280                temperature: None,
281            },
282        });
283    }
284}
285
286fn extract_template_kwarg(
287    _call: Node,
288    args: Option<Node>,
289    source: &str,
290    path: &Path,
291    docs: &mut Vec<Document>,
292) {
293    let Some(args) = args else { return };
294
295    for i in 0..args.named_child_count() {
296        if let Some(child) = args.named_child(i) {
297            if child.kind() == "keyword_argument" {
298                if let Some(name) = child.child_by_field_name("name") {
299                    if let Ok(name_text) = node_text(&name, source) {
300                        if name_text == "template" {
301                            if let Some(value) = child.child_by_field_name("value") {
302                                if let Some((text, span)) = resolve_string_literal(value, source) {
303                                    docs.push(Document {
304                                        path: path.to_path_buf(),
305                                        role: Role::Unknown,
306                                        source: source.to_string(),
307                                        prompt: PromptText {
308                                            text,
309                                            origin_span: Some(span),
310                                        },
311                                        kind: Kind::ExtractedPython {
312                                            call_site: Span::from_byte_range(
313                                                source,
314                                                child.start_byte()..child.end_byte(),
315                                            ),
316                                            temperature: None,
317                                        },
318                                    });
319                                }
320                            }
321                        }
322                    }
323                }
324            }
325        }
326    }
327}
328
329fn extract_first_positional_string(
330    call: Node,
331    args: Option<Node>,
332    source: &str,
333    path: &Path,
334    docs: &mut Vec<Document>,
335    role: Role,
336) {
337    let Some(args) = args else { return };
338
339    for i in 0..args.named_child_count() {
340        if let Some(child) = args.named_child(i) {
341            let is_string_arg = child.kind() == "string" || child.kind() == "argument";
342            if is_string_arg {
343                if let Some((text, span)) = resolve_string_literal(child, source) {
344                    docs.push(Document {
345                        path: path.to_path_buf(),
346                        role,
347                        source: source.to_string(),
348                        prompt: PromptText {
349                            text,
350                            origin_span: Some(span),
351                        },
352                        kind: Kind::ExtractedPython {
353                            call_site: Span::from_byte_range(
354                                source,
355                                call.start_byte()..call.end_byte(),
356                            ),
357                            temperature: None,
358                        },
359                    });
360                    return; // Only first positional
361                }
362            }
363        }
364    }
365}
366
367fn extract_from_messages_list(
368    _call: Node,
369    args: Option<Node>,
370    source: &str,
371    path: &Path,
372    docs: &mut Vec<Document>,
373) {
374    let Some(args) = args else { return };
375
376    for i in 0..args.named_child_count() {
377        if let Some(child) = args.named_child(i) {
378            if child.kind() == "list" {
379                for j in 0..child.named_child_count() {
380                    if let Some(item) = child.named_child(j) {
381                        extract_from_messages_tuple(item, source, path, docs);
382                    }
383                }
384            }
385        }
386    }
387}
388
389fn extract_from_messages_tuple(tuple: Node, source: &str, path: &Path, docs: &mut Vec<Document>) {
390    if tuple.kind() != "tuple" {
391        return;
392    }
393
394    let mut role = None;
395    let mut content = None;
396
397    for i in 0..tuple.named_child_count() {
398        if let Some(child) = tuple.named_child(i) {
399            match i {
400                0 => {
401                    if let Ok(text) = node_text(&child, source) {
402                        role = Some(text.trim_matches('\"').trim_matches('\'').to_string());
403                    }
404                }
405                1 => {
406                    content = resolve_string_literal(child, source);
407                }
408                _ => {}
409            }
410        }
411    }
412
413    if let (Some(role_str), Some((text, origin_span))) = (role, content) {
414        let role_enum = match role_str.as_str() {
415            "system" => Role::System,
416            "user" => Role::User,
417            "assistant" => Role::Assistant,
418            "tool" => Role::Tool,
419            _ => Role::Unknown,
420        };
421
422        docs.push(Document {
423            path: path.to_path_buf(),
424            role: role_enum,
425            source: source.to_string(),
426            prompt: PromptText {
427                text,
428                origin_span: Some(origin_span),
429            },
430            kind: Kind::ExtractedPython {
431                call_site: Span::from_byte_range(source, tuple.start_byte()..tuple.end_byte()),
432                temperature: None,
433            },
434        });
435    }
436}
437
438/// Resolve a string literal node to (text, origin_span).
439/// Handles: plain strings, f-strings (with placeholder substitution), raw strings.
440/// Returns None for dynamic expressions, names, etc.
441fn resolve_string_literal(node: Node, source: &str) -> Option<(String, Span)> {
442    if node.kind() != "string" {
443        return None;
444    }
445
446    let start = node.start_byte();
447    let end = node.end_byte();
448    let span = Span::from_byte_range(source, start..end);
449
450    let raw_text = &source[start..end];
451
452    if raw_text.starts_with("f\"")
453        || raw_text.starts_with("f'")
454        || raw_text.starts_with("F\"")
455        || raw_text.starts_with("F'")
456        || raw_text.starts_with("rf\"")
457        || raw_text.starts_with("fr\"")
458        || raw_text.starts_with("rf'")
459        || raw_text.starts_with("fr'")
460    {
461        let text = reconstruct_fstring(node, source);
462        return Some((text, span));
463    }
464
465    if raw_text.starts_with("r\"")
466        || raw_text.starts_with("r'")
467        || raw_text.starts_with("R\"")
468        || raw_text.starts_with("R'")
469    {
470        let quote_char = if raw_text.contains("\"\"\"") || raw_text.contains("'''") {
471            &raw_text[2..5]
472        } else {
473            &raw_text[2..3]
474        };
475        let inner = extract_string_inner(raw_text, quote_char);
476        return Some((inner, span));
477    }
478
479    let quote_char = if raw_text.starts_with("\"\"\"") || raw_text.starts_with("'''") {
480        &raw_text[..3]
481    } else {
482        &raw_text[..1]
483    };
484
485    let inner = extract_string_inner(raw_text, quote_char);
486    let unescaped = unescape_string(&inner);
487    Some((unescaped, span))
488}
489
/// Returns the text between the quotes of a Python string literal.
///
/// `raw` is the literal exactly as written, e.g. `rf"""..."""`, and `quote` is its
/// delimiter (`"`, `'`, `"""` or `'''`). Any string prefix is dropped first: the
/// ASCII letters before the first quote, in any case and order (`r`, `f`, `u`, `b`,
/// `Rf`, `bR`, …). One `quote` is then stripped from each end when present.
/// No unescaping is done.
fn extract_string_inner(raw: &str, quote: &str) -> String {
    let body = match raw.find(|c: char| c == '"' || c == '\'') {
        Some(pos) if raw[..pos].chars().all(|c| c.is_ascii_alphabetic()) => &raw[pos..],
        _ => raw,
    };
    let body = body.strip_prefix(quote).unwrap_or(body);
    body.strip_suffix(quote).unwrap_or(body).to_string()
}
506
/// Decodes the backslash escapes in the body of a non-raw Python string.
///
/// Supported, following Python's string-literal rules:
/// * `\n`, `\t`, `\r`, `\\`, `\"`, `\'`, `\a`, `\b`, `\f`, `\v`;
/// * a backslash before a line break (`\n` or `\r\n`): a line continuation that
///   produces no output;
/// * octal escapes `\o`, `\oo`, `\ooo`;
/// * hex `\xhh` and Unicode `\uXXXX` / `\UXXXXXXXX`.
///
/// Everything else is kept verbatim, backslash included: unknown escapes such as
/// `\q`, named escapes `\N{...}`, malformed hex/Unicode escapes, and a trailing
/// lone backslash.
fn unescape_string(s: &str) -> String {
    /// Decodes `width` hex digits that follow the escape letter at the start of
    /// `after`. Returns the char and the bytes consumed (letter plus digits).
    fn hex_escape(after: &str, width: usize) -> Option<(char, usize)> {
        let digits = after.get(1..1 + width)?;
        // `from_str_radix` on its own would also accept a leading `+`.
        if !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        let code = u32::from_str_radix(digits, 16).ok()?;
        char::from_u32(code).map(|c| (c, 1 + width))
    }

    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(pos) = rest.find('\\') {
        out.push_str(&rest[..pos]);
        // Text after the backslash; its first char selects the escape.
        let after = &rest[pos + 1..];
        let Some(esc) = after.chars().next() else {
            // Trailing lone backslash: keep it.
            out.push('\\');
            rest = after;
            break;
        };

        let simple = match esc {
            'n' => Some('\n'),
            't' => Some('\t'),
            'r' => Some('\r'),
            '\\' => Some('\\'),
            '"' => Some('"'),
            '\'' => Some('\''),
            'a' => Some('\u{07}'),
            'b' => Some('\u{08}'),
            'f' => Some('\u{0C}'),
            'v' => Some('\u{0B}'),
            _ => None,
        };

        // Bytes of `after` consumed by this escape. Every consumed char is ASCII,
        // so `&after[consumed..]` stays on a char boundary.
        let consumed = if let Some(ch) = simple {
            out.push(ch);
            1
        } else {
            match esc {
                // Line continuation: backslash + line break vanish entirely.
                '\n' => 1,
                '\r' => {
                    if after[1..].starts_with('\n') {
                        2
                    } else {
                        1
                    }
                }
                '0'..='7' => {
                    let len = after
                        .bytes()
                        .take(3)
                        .take_while(|b| (b'0'..=b'7').contains(b))
                        .count();
                    let code = u32::from_str_radix(&after[..len], 8).unwrap_or(0);
                    out.push(char::from_u32(code).unwrap_or(char::REPLACEMENT_CHARACTER));
                    len
                }
                'x' | 'u' | 'U' => {
                    let width = match esc {
                        'x' => 2,
                        'u' => 4,
                        _ => 8,
                    };
                    match hex_escape(after, width) {
                        Some((ch, len)) => {
                            out.push(ch);
                            len
                        }
                        None => {
                            out.push('\\');
                            0
                        }
                    }
                }
                _ => {
                    // Unknown escape: keep the backslash; the char follows as text.
                    out.push('\\');
                    0
                }
            }
        };
        rest = &after[consumed..];
    }

    out.push_str(rest);
    out
}
547
548fn reconstruct_fstring(node: Node, source: &str) -> String {
549    let mut result = String::new();
550    let mut placeholder_index = 0;
551
552    for i in 0..node.named_child_count() {
553        if let Some(child) = node.named_child(i) {
554            match child.kind() {
555                "string_content" => {
556                    if let Ok(text) = node_text(&child, source) {
557                        let unescaped = unescape_string(&text);
558                        result.push_str(&unescaped);
559                    }
560                }
561                "interpolation" => {
562                    result.push_str(&format!("{{{}}}", placeholder_index));
563                    placeholder_index += 1;
564                }
565                _ => {}
566            }
567        }
568    }
569
570    if result.is_empty() {
571        let start = node.start_byte();
572        let end = node.end_byte();
573        if start < end && end <= source.len() {
574            let raw = &source[start..end];
575            let quote = if raw.contains("\"\"\"") || raw.contains("'''") {
576                &raw[..3]
577            } else {
578                &raw[2..3]
579            };
580            extract_string_inner(raw, quote)
581        } else {
582            String::new()
583        }
584    } else {
585        result
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as `t.py` and returns every extracted document.
    fn all(src: &str) -> Vec<Document> {
        parse(Path::new("t.py"), src).unwrap()
    }

    /// Parses `src` and returns the first extracted document.
    fn first(src: &str) -> Document {
        let mut docs = all(src);
        assert!(!docs.is_empty(), "expected at least one document");
        docs.remove(0)
    }

    /// Returns the temperature recorded on an extracted-Python document.
    fn temperature_of(doc: &Document) -> Option<f32> {
        match &doc.kind {
            Kind::ExtractedPython { temperature, .. } => *temperature,
            #[allow(unreachable_patterns)]
            _ => panic!("expected a Kind::ExtractedPython document"),
        }
    }

    #[test]
    fn anthropic_system_extracted() {
        let src = r#"
client.messages.create(
    model="claude-4.7-opus",
    system="You are a helpful assistant.",
    messages=[{"role": "user", "content": "Hello"}],
)
"#;
        let d = first(src);
        assert_eq!(d.prompt.text, "You are a helpful assistant.");
        assert_eq!(d.role, Role::System);
    }

    #[test]
    fn openai_messages_extracted() {
        let src = r#"
openai.chat.completions.create(
    messages=[
        {"role": "system", "content": "Act as a tutor."},
        {"role": "user", "content": "Teach me fractions."},
    ],
)
"#;
        let docs = all(src);
        assert_eq!(docs.len(), 2);
        let sys = docs.iter().find(|d| d.role == Role::System).unwrap();
        assert_eq!(sys.prompt.text, "Act as a tutor.");
        let user = docs.iter().find(|d| d.role == Role::User).unwrap();
        assert_eq!(user.prompt.text, "Teach me fractions.");
    }

    #[test]
    fn prompttemplate_from_template() {
        let src = r#"PromptTemplate.from_template("Answer this: {q}")"#;
        let docs = all(src);
        assert!(
            !docs.is_empty(),
            "Expected at least one document, got {}",
            docs.len()
        );
        assert_eq!(docs[0].prompt.text, "Answer this: {q}");
    }

    #[test]
    fn chatprompttemplate_from_messages() {
        let src = r#"
ChatPromptTemplate.from_messages([
    ("system", "You are helpful."),
    ("user", "Q: {q}"),
])
"#;
        let docs = all(src);
        assert_eq!(docs.len(), 2);
        assert_eq!(docs[0].role, Role::System);
        assert_eq!(docs[0].prompt.text, "You are helpful.");
        assert_eq!(docs[1].role, Role::User);
        assert_eq!(docs[1].prompt.text, "Q: {q}");
    }

    #[test]
    fn fstring_becomes_positional_placeholder() {
        let src = r#"
client.messages.create(
    system=f"You are {name}. Tone: {tone}.",
    messages=[],
)
"#;
        let d = first(src);
        assert_eq!(d.prompt.text, "You are {0}. Tone: {1}.");
    }

    #[test]
    fn dynamic_expression_skipped() {
        let src = r#"
client.messages.create(
    system=SOMETHING_DYNAMIC,
    messages=[],
)
"#;
        let docs = all(src);
        assert!(docs.is_empty());
    }

    #[test]
    fn triple_quoted_string_extracted() {
        let src = r#"
client.messages.create(
    system="""Be terse.""",
    messages=[],
)
"#;
        let d = first(src);
        assert_eq!(d.prompt.text, "Be terse.");
        assert_eq!(d.role, Role::System);
    }

    #[test]
    fn temperature_attached_to_call_documents() {
        let src = r#"
client.messages.create(
    system="Be brief.",
    messages=[{"role": "user", "content": "Hi"}],
    temperature=0.2,
)
"#;
        let docs = all(src);
        assert_eq!(docs.len(), 2);
        for doc in &docs {
            assert_eq!(temperature_of(doc), Some(0.2));
        }
    }
}