python_proto_importer/postprocess/apply.rs

1use anyhow::{Context, Result};
2#[allow(unused_imports)]
3use prost_reflect::DescriptorPool;
4use regex::Regex;
5use std::fs;
6use std::io::Write;
7use std::path::{Component, Path, PathBuf};
8use walkdir::WalkDir;
9
/// Maps a dotted Python package (`module_path`) and a module name (`leaf`) to the `.py` file
/// that would hold that module under `root`.
///
/// Empty dotted segments (from `""`, `"a..b"`, or leading/trailing dots) are skipped, so an
/// empty `module_path` places the file directly inside `root`.
fn path_from_module(root: &Path, module_path: &str, leaf: &str) -> PathBuf {
    let package_dirs = module_path.split('.').filter(|segment| !segment.is_empty());
    let mut file = root.to_path_buf();
    file.extend(package_dirs);
    file.push(format!("{leaf}.py"));
    file
}
22
/// Splits a dotted name at its final `.` into `(package, leaf)`.
///
/// A name without any dot yields an empty package and the whole name as the leaf.
fn split_module_qualname(qualified: &str) -> (String, String) {
    match qualified.rsplit_once('.') {
        Some((package, leaf)) => (package.to_owned(), leaf.to_owned()),
        None => (String::new(), qualified.to_owned()),
    }
}
33
/// Describes how to reach `to_dir` from `from_dir` for a Python relative import.
///
/// Returns `(ups, remainder)`:
/// - `ups` is the number of directory levels from `from_dir` up to the deepest common ancestor.
/// - `remainder` is the dotted path from that ancestor down to `to_dir`. It is empty when
///   `to_dir` is the ancestor itself.
///
/// Callers render this as `ups + 1` leading dots followed by `remainder`.
///
/// Both paths are canonicalized (resolving symlinks, `.` and `..`) only when *both* can be.
/// Otherwise both are compared exactly as given, so an absolute canonical path is never
/// compared against a raw, possibly relative one. `.` components are ignored either way.
/// Only `Normal` components contribute to `remainder`.
///
/// Currently always returns `Some`.
fn compute_relative_import_prefix(from_dir: &Path, to_dir: &Path) -> Option<(usize, String)> {
    let (from_path, to_path): (PathBuf, PathBuf) =
        match (std::fs::canonicalize(from_dir), std::fs::canonicalize(to_dir)) {
            (Ok(from), Ok(to)) => (from, to),
            // Mixing a canonical (absolute) path with a raw one would make the common prefix
            // meaningless, so fall back to the raw form for both.
            _ => (from_dir.to_path_buf(), to_dir.to_path_buf()),
        };

    // A `.` segment names no directory and must not be counted as a level to climb.
    let from: Vec<Component<'_>> = from_path
        .components()
        .filter(|c| !matches!(c, Component::CurDir))
        .collect();
    let to: Vec<Component<'_>> = to_path
        .components()
        .filter(|c| !matches!(c, Component::CurDir))
        .collect();

    let common = from.iter().zip(&to).take_while(|(a, b)| a == b).count();
    let ups = from.len() - common;
    let remainder = to[common..]
        .iter()
        .filter_map(|comp| match comp {
            Component::Normal(os) => Some(os.to_string_lossy().into_owned()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join(".");
    Some((ups, remainder))
}
63
64#[allow(clippy::collapsible_if)]
65fn rewrite_lines_in_content(
66    content: &str,
67    file_dir: &Path,
68    root: &Path,
69    exclude_google: bool,
70) -> Result<(String, bool)> {
71    let mut changed = false;
72    let mut out = String::with_capacity(content.len());
73    // map of fully-qualified module -> local name to use in annotations
74    let mut module_rewrites: Vec<(String, String)> = Vec::new();
75
76    let re_import = Regex::new(
77        r"^(?P<indent>\s*)import\s+(?P<mod>[A-Za-z0-9_\.]+)\s+as\s+(?P<alias>[A-Za-z0-9_]+)\s*(?:#.*)?$",
78    )
79    .unwrap();
80    let re_from = Regex::new(r"^(?P<indent>\s*)from\s+(?P<pkg>[A-Za-z0-9_\.]+)\s+import\s+(?P<name>[A-Za-z0-9_]+)(?:\s+as\s+(?P<alias>[A-Za-z0-9_]+))?\s*(?:#.*)?$").unwrap();
81    let re_from_any =
82        Regex::new(r"^(?P<indent>\s*)from\s+(?P<pkg>[A-Za-z0-9_\.]+)\s+import\s+(?P<rest>.*)$")
83            .unwrap();
84    let re_import_simple =
85        Regex::new(r"^(?P<indent>\s*)import\s+(?P<mod>[A-Za-z0-9_\.]+)\s*(?:#.*)?$").unwrap();
86    let re_import_list = Regex::new(r"^(?P<indent>\s*)import\s+(?P<rest>.+)$").unwrap();
87
88    // State for collecting parenthesized multi-line 'from ... import (...)' blocks
89    let mut pending_from_block: Option<(String, String, String)> = None; // (indent, pkg, collected)
90
91    for line in content.lines() {
92        // Handle continuation of a parenthesized from-import block
93        if let Some((indent, pkg, mut collected)) = pending_from_block.take() {
94            collected.push('\n');
95            collected.push_str(line);
96            // Check if parentheses are balanced now
97            let opens = collected.matches('(').count();
98            let closes = collected.matches(')').count();
99            if closes < opens {
100                // Still pending
101                pending_from_block = Some((indent, pkg, collected));
102                continue;
103            }
104
105            // Process the full block
106            let processed = process_from_import_list(
107                &indent,
108                &pkg,
109                &collected,
110                file_dir,
111                root,
112                exclude_google,
113            )?;
114            out.push_str(&processed.output);
115            changed |= processed.changed;
116            continue;
117        }
118        if line.trim_start().starts_with("from .") {
119            out.push_str(line);
120            out.push('\n');
121            continue;
122        }
123        if let Some(caps) = re_import_simple.captures(line) {
124            let indent = &caps["indent"];
125            let module = &caps["mod"];
126            if !module.ends_with("_pb2") && !module.ends_with("_pb2_grpc") {
127                out.push_str(line);
128                out.push('\n');
129                continue;
130            }
131            if exclude_google && module.starts_with("google.protobuf") {
132                out.push_str(line);
133                out.push('\n');
134                continue;
135            }
136            let (module_path, leaf) = split_module_qualname(module);
137            let target = path_from_module(root, &module_path, &leaf);
138            if !target.exists() {
139                out.push_str(line);
140                out.push('\n');
141                continue;
142            }
143            if let Some((ups, remainder)) =
144                compute_relative_import_prefix(file_dir, target.parent().unwrap_or(root))
145            {
146                // ups=0 -> "." (current), ups=1 -> ".." (parent)
147                let dots = ".".repeat(ups + 1);
148                let from_pkg = if remainder.is_empty() {
149                    dots
150                } else {
151                    format!("{dots}{remainder}")
152                };
153                let new_line = format!("{indent}from {from_pkg} import {leaf}");
154                out.push_str(&new_line);
155                out.push('\n');
156                changed = true;
157                module_rewrites.push((module.to_string(), leaf.to_string()));
158                continue;
159            }
160        }
161
162        // Handle comma-separated 'import a, b as c' by splitting into tokens
163        if let Some(caps) = re_import_list.captures(line) {
164            let indent = &caps["indent"];
165            let rest = &caps["rest"]; // may contain commas and aliases
166            if rest.contains(',') {
167                let mut any_local_change = false;
168                for tok in rest.split(',') {
169                    let token = tok.trim();
170                    if token.is_empty() {
171                        continue;
172                    }
173                    // token: module[ as alias]
174                    let mut parts = token.split_whitespace().collect::<Vec<_>>();
175                    if parts.is_empty() {
176                        continue;
177                    }
178                    // reconstruct alias if provided
179                    let mut alias: Option<&str> = None;
180                    if parts.len() >= 3 && parts[parts.len() - 2] == "as" {
181                        alias = Some(parts[parts.len() - 1]);
182                        parts.truncate(parts.len() - 2);
183                    }
184                    let module = parts.join(" ");
185                    let mut rewritten = false;
186                    if (module.ends_with("_pb2") || module.ends_with("_pb2_grpc"))
187                        && !(exclude_google && module.starts_with("google.protobuf"))
188                    {
189                        let (module_path, leaf) = split_module_qualname(&module);
190                        let target = path_from_module(root, &module_path, &leaf);
191                        if target.exists() {
192                            if let Some((ups, remainder)) = compute_relative_import_prefix(
193                                file_dir,
194                                target.parent().unwrap_or(root),
195                            ) {
196                                let dots = ".".repeat(ups + 1);
197                                let from_pkg = if remainder.is_empty() {
198                                    dots
199                                } else {
200                                    format!("{dots}{remainder}")
201                                };
202                                if let Some(a) = alias {
203                                    out.push_str(&format!(
204                                        "{indent}from {from_pkg} import {leaf} as {a}\n"
205                                    ));
206                                } else {
207                                    out.push_str(&format!(
208                                        "{indent}from {from_pkg} import {leaf}\n"
209                                    ));
210                                }
211                                changed = true;
212                                any_local_change = true;
213                                rewritten = true;
214                            }
215                        }
216                    }
217                    if !rewritten {
218                        // Fallback: keep original token as a separate import line
219                        out.push_str(&format!("{indent}import {token}\n"));
220                    }
221                }
222                if any_local_change {
223                    continue;
224                }
225            }
226        }
227
228        if let Some(caps) = re_import.captures(line) {
229            let indent = &caps["indent"];
230            let module = &caps["mod"];
231            let alias = &caps["alias"];
232            if !module.ends_with("_pb2") && !module.ends_with("_pb2_grpc") {
233                out.push_str(line);
234                out.push('\n');
235                continue;
236            }
237            if exclude_google && module.starts_with("google.protobuf") {
238                out.push_str(line);
239                out.push('\n');
240                continue;
241            }
242            let (module_path, leaf) = split_module_qualname(module);
243            let target = path_from_module(root, &module_path, &leaf);
244            if !target.exists() {
245                out.push_str(line);
246                out.push('\n');
247                continue;
248            }
249            if let Some((ups, remainder)) =
250                compute_relative_import_prefix(file_dir, target.parent().unwrap_or(root))
251            {
252                // ups=0 -> "." (current), ups=1 -> ".." (parent)
253                let dots = ".".repeat(ups + 1);
254                let from_pkg = if remainder.is_empty() {
255                    dots
256                } else {
257                    format!("{dots}{remainder}")
258                };
259                let new_line = format!("{indent}from {from_pkg} import {leaf} as {alias}");
260                out.push_str(&new_line);
261                out.push('\n');
262                changed = true;
263                module_rewrites.push((module.to_string(), alias.to_string()));
264                continue;
265            }
266        }
267        // Handle single-name 'from pkg import name [as alias]'
268        if let Some(caps) = re_from.captures(line) {
269            let indent = &caps["indent"];
270            let pkg = &caps["pkg"];
271            let name = &caps["name"];
272            let alias = caps.name("alias").map(|m| m.as_str());
273            if !name.ends_with("_pb2") && !name.ends_with("_pb2_grpc") {
274                out.push_str(line);
275                out.push('\n');
276                continue;
277            }
278            if exclude_google && pkg.starts_with("google.protobuf") {
279                out.push_str(line);
280                out.push('\n');
281                continue;
282            }
283            let target = path_from_module(root, pkg, name);
284            if !target.exists() {
285                out.push_str(line);
286                out.push('\n');
287                continue;
288            }
289            if let Some((ups, remainder)) =
290                compute_relative_import_prefix(file_dir, target.parent().unwrap_or(root))
291            {
292                // ups=0 -> same level (use "." + remainder)
293                // ups=1 -> parent level (use ".." + remainder)
294                let dots = if ups == 0 {
295                    ".".to_string()
296                } else {
297                    ".".repeat(ups + 1)
298                };
299                let from_pkg = if remainder.is_empty() {
300                    dots
301                } else {
302                    format!("{dots}{remainder}")
303                };
304                let new_line = if let Some(a) = alias {
305                    format!("{indent}from {from_pkg} import {name} as {a}")
306                } else {
307                    format!("{indent}from {from_pkg} import {name}")
308                };
309                out.push_str(&new_line);
310                out.push('\n');
311                changed = true;
312                // fully-qualified = pkg.name
313                let fq = if pkg.is_empty() {
314                    name.to_string()
315                } else {
316                    format!("{pkg}.{name}")
317                };
318                let local = alias
319                    .map(|s| s.to_string())
320                    .unwrap_or_else(|| name.to_string());
321                module_rewrites.push((fq, local));
322                continue;
323            }
324        }
325
326        // Handle 'from pkg import a, b as c' (single-line) or start of parenthesized block
327        if let Some(caps) = re_from_any.captures(line) {
328            let indent = caps["indent"].to_string();
329            let pkg = caps["pkg"].to_string();
330            let rest = caps["rest"].trim();
331            if rest.starts_with('(') && !rest.contains(')') {
332                // Begin collecting a multi-line parenthesized block
333                pending_from_block = Some((indent, pkg, line.to_string()));
334                continue;
335            }
336            if rest.contains(',') || rest.starts_with('(') {
337                // Process possibly parenthesized single-line list
338                let processed =
339                    process_from_import_list(&indent, &pkg, line, file_dir, root, exclude_google)?;
340                out.push_str(&processed.output);
341                changed |= processed.changed;
342                continue;
343            }
344        }
345        out.push_str(line);
346        out.push('\n');
347    }
348    // After rewriting imports, fix fully-qualified references in annotations
349    if !module_rewrites.is_empty() {
350        for (from_mod, to_name) in module_rewrites.iter() {
351            // replace occurrences like "from_mod.*" to "to_name.*"
352            let pattern = regex::Regex::new(&format!(r"\b{}\.", regex::escape(from_mod))).unwrap();
353            let replaced = pattern.replace_all(&out, format!("{}.", to_name));
354            let new_str = replaced.into_owned();
355            if new_str != out {
356                changed = true;
357                out = new_str;
358            }
359        }
360    }
361
362    Ok((out, changed))
363}
364
/// Outcome of [`process_from_import_list`] for a single `from ... import ...` statement.
struct FromImportProcessResult {
    /// Replacement text for the statement: one or more lines, each terminated by `\n`.
    output: String,
    /// `true` when at least one imported item was rewritten to a relative import.
    changed: bool,
}
369
370fn process_from_import_list(
371    indent: &str,
372    pkg: &str,
373    full_line_or_block: &str,
374    file_dir: &Path,
375    root: &Path,
376    exclude_google: bool,
377) -> Result<FromImportProcessResult> {
378    // Extract everything after 'from <pkg> import'
379    let after_import = full_line_or_block
380        .split_once(" import ")
381        .map(|(_, s)| s.trim())
382        .unwrap_or_else(|| full_line_or_block.trim());
383
384    // Remove wrapping parentheses and trailing comment lines
385    let mut inner = after_import.trim();
386    if inner.starts_with('(') {
387        // Remove the first '(' and the matching last ')'
388        // For robustness, just trim leading '(' and trailing ')' and whitespace/commas
389        inner = inner.trim_start_matches('(');
390        inner = inner.trim_end();
391        if inner.ends_with(')') {
392            inner = &inner[..inner.len() - 1];
393        }
394    }
395
396    // Split by commas across potential multi-lines
397    let mut tokens: Vec<String> = Vec::new();
398    for raw in inner.lines() {
399        let no_comment = match raw.find('#') {
400            Some(idx) => &raw[..idx],
401            None => raw,
402        };
403        for part in no_comment.split(',') {
404            let t = part.trim();
405            if !t.is_empty() {
406                tokens.push(t.to_string());
407            }
408        }
409    }
410
411    if tokens.is_empty() {
412        return Ok(FromImportProcessResult {
413            output: format!("{}from {} import {}\n", indent, pkg, inner.trim()),
414            changed: false,
415        });
416    }
417
418    // Partition into rewritable and others
419    let mut rewrite_items: Vec<(String, Option<String>)> = Vec::new();
420    let mut keep_items: Vec<String> = Vec::new();
421    for tok in tokens {
422        // token: name [as alias]
423        let mut name = tok.as_str();
424        let mut alias: Option<String> = None;
425        if let Some(pos) = tok.rfind(" as ") {
426            name = tok[..pos].trim();
427            alias = Some(tok[pos + 4..].trim().to_string());
428        }
429        if (name.ends_with("_pb2") || name.ends_with("_pb2_grpc"))
430            && !(exclude_google && pkg.starts_with("google.protobuf"))
431        {
432            // Check target exists
433            let target = path_from_module(root, pkg, name);
434            if target.exists() {
435                rewrite_items.push((name.to_string(), alias));
436                continue;
437            }
438        }
439        keep_items.push(tok);
440    }
441
442    if rewrite_items.is_empty() {
443        // No change
444        return Ok(FromImportProcessResult {
445            output: format!("{}{}\n", indent, full_line_or_block.trim()),
446            changed: false,
447        });
448    }
449
450    // Compute relative from-pkg using any one item's target (they share pkg)
451    let any_name = &rewrite_items[0].0;
452    let target = path_from_module(root, pkg, any_name);
453    let (ups, remainder) =
454        compute_relative_import_prefix(file_dir, target.parent().unwrap_or(root))
455            .unwrap_or((0, String::new()));
456    let dots = if ups == 0 {
457        ".".to_string()
458    } else {
459        ".".repeat(ups + 1)
460    };
461    let from_pkg = if remainder.is_empty() {
462        dots
463    } else {
464        format!("{dots}{remainder}")
465    };
466
467    // Build output lines: first the rewritten relative import
468    let mut output = String::new();
469    let list = rewrite_items
470        .into_iter()
471        .map(|(n, a)| match a {
472            Some(x) => format!("{} as {}", n, x),
473            None => n,
474        })
475        .collect::<Vec<_>>()
476        .join(", ");
477    output.push_str(&format!("{}from {} import {}\n", indent, from_pkg, list));
478
479    // Keep the remaining items via original absolute import if any
480    if !keep_items.is_empty() {
481        let keep_list = keep_items.join(", ");
482        output.push_str(&format!("{}from {} import {}\n", indent, pkg, keep_list));
483    }
484
485    Ok(FromImportProcessResult {
486        output,
487        changed: true,
488    })
489}
490
491#[allow(dead_code)]
492pub fn apply_rewrites_in_tree(
493    root: &Path,
494    exclude_google: bool,
495    module_suffixes: &[String],
496    allowed_basenames: Option<&std::collections::HashSet<String>>,
497) -> Result<usize> {
498    let mut modified = 0usize;
499    for entry in WalkDir::new(root).into_iter().filter_map(Result::ok) {
500        let p = entry.path();
501        if p.is_file() {
502            let rel = p.strip_prefix(root).unwrap_or(p).to_string_lossy();
503            let mut matched = false;
504            for s in module_suffixes {
505                if (s.ends_with(".py") || s.ends_with(".pyi")) && rel.ends_with(s) {
506                    matched = true;
507                    break;
508                }
509            }
510            if !matched {
511                continue;
512            }
513            let content = fs::read_to_string(p).with_context(|| format!("read {}", p.display()))?;
514            // Pre-filter: if allowed_basenames (from FDS) are provided,
515            // skip files that don't contain any target basename
516            if matches!(
517                allowed_basenames,
518                Some(allowed) if !allowed.iter().any(|b| content.contains(b))
519            ) {
520                continue;
521            }
522            let (new_content, changed) = rewrite_lines_in_content(
523                &content,
524                p.parent().unwrap_or(root),
525                root,
526                exclude_google,
527            )?;
528            if changed {
529                let mut f = fs::OpenOptions::new()
530                    .write(true)
531                    .truncate(true)
532                    .open(p)
533                    .with_context(|| format!("open {} for write", p.display()))?;
534                f.write_all(new_content.as_bytes())
535                    .with_context(|| format!("write {}", p.display()))?;
536                modified += 1;
537            }
538        }
539    }
540    Ok(modified)
541}
542
// Unit tests for the import-rewriting helpers and the tree walker.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn compute_prefix_basic() {
        // Neither path exists, so they are compared exactly as given.
        let from = Path::new("/a/b");
        let to = Path::new("/a/c/d");
        let (ups, rem) = compute_relative_import_prefix(from, to).unwrap();
        assert_eq!(ups, 1);
        assert_eq!(rem, "c.d");
    }

    #[test]
    fn compute_prefix_same_level() {
        // Test sibling directories: billing/ and order/ under generated/
        let from = Path::new("generated/billing");
        let to = Path::new("generated/order");
        let (ups, rem) = compute_relative_import_prefix(from, to).unwrap();
        assert_eq!(ups, 1); // Go up one level to parent, then down to sibling
        assert_eq!(rem, "order");
    }

    #[test]
    fn compute_prefix_with_relative_segments() {
        // from: ./a/./b, to: a/c/../c/d -> expect up 1 and remainder c.d
        let tmp = tempdir().unwrap();
        let root = tmp.path();
        std::fs::create_dir_all(root.join("a/b")).unwrap();
        std::fs::create_dir_all(root.join("a/c/d")).unwrap();

        let from = root.join("./a/./b");
        let to = root.join("a/c/../c/d");
        let (ups, rem) = compute_relative_import_prefix(&from, &to).unwrap();
        assert_eq!(ups, 1);
        assert_eq!(rem, "c.d");
    }

    #[cfg(unix)]
    #[test]
    fn compute_prefix_with_symlink() {
        use std::os::unix::fs::symlink;
        let tmp = tempdir().unwrap();
        let root = tmp.path();
        std::fs::create_dir_all(root.join("real/order")).unwrap();
        std::fs::create_dir_all(root.join("real/billing")).unwrap();
        // symlink 'gen' -> 'real'
        symlink(root.join("real"), root.join("gen")).unwrap();

        let from = root.join("gen/billing");
        let to = root.join("real/order");
        let (ups, rem) = compute_relative_import_prefix(&from, &to).unwrap();
        // After canonicalize, common prefix is root/real, expect up 1 and remainder order
        assert_eq!(ups, 1);
        assert_eq!(rem, "order");
    }

    #[test]
    fn rewrite_import_alias() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        // target module at root/a_pb2.py
        fs::write(root.join("a_pb2.py"), "# stub").unwrap();
        // file under sub/needs.py
        let sub = root.join("sub");
        fs::create_dir_all(&sub).unwrap();
        let content = "import a_pb2 as a__pb2\n";
        let (out, changed) = rewrite_lines_in_content(content, &sub, root, false).unwrap();
        assert!(changed);
        assert_eq!(out, "from .. import a_pb2 as a__pb2\n");
    }

    #[test]
    fn rewrite_pyi_simple_import() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        fs::write(root.join("a_pb2.py"), "# stub").unwrap();
        let sub = root.join("pkg");
        fs::create_dir_all(&sub).unwrap();
        let content = "import a_pb2\n";
        let (out, changed) = rewrite_lines_in_content(content, &sub, root, false).unwrap();
        assert!(changed);
        assert_eq!(out, "from .. import a_pb2\n");
    }

    #[test]
    fn skip_google_protobuf() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        // no need to create files; should skip due to exclude_google
        let content = "import google.protobuf.timestamp_pb2 as timestamp__pb2\n";
        let (out, changed) = rewrite_lines_in_content(content, root, root, true).unwrap();
        assert!(!changed);
        assert_eq!(out, content);
    }

    #[test]
    fn apply_rewrites_suffix_filter() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        // create structure
        fs::create_dir_all(root.join("x")).unwrap();
        fs::write(root.join("a_pb2.py"), "# a\n").unwrap();
        fs::write(root.join("x/b_pb2.py"), "import a_pb2 as a__pb2\n").unwrap();
        fs::write(root.join("c.py"), "import a_pb2 as a__pb2\n").unwrap();
        let modified = apply_rewrites_in_tree(root, false, &["_pb2.py".into()], None).unwrap();
        // only x/b_pb2.py should be modified
        assert_eq!(modified, 1);
        let b = fs::read_to_string(root.join("x/b_pb2.py")).unwrap();
        assert_eq!(b, "from .. import a_pb2 as a__pb2\n");
        let c = fs::read_to_string(root.join("c.py")).unwrap();
        assert_eq!(c, "import a_pb2 as a__pb2\n");
    }

    #[test]
    fn rewrite_from_multi_items_single_line() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        // structure: pkg/a_pb2.py and pkg/b_pb2_grpc.py
        std::fs::create_dir_all(root.join("pkg")).unwrap();
        fs::write(root.join("pkg/a_pb2.py"), "# a").unwrap();
        fs::write(root.join("pkg/b_pb2_grpc.py"), "# b").unwrap();
        let file_dir = root.join("pkg");
        let content = "from pkg import a_pb2, b_pb2_grpc as bgrpc\n";
        let (out, changed) = rewrite_lines_in_content(content, &file_dir, root, false).unwrap();
        assert!(changed);
        assert_eq!(out.trim_end(), "from . import a_pb2, b_pb2_grpc as bgrpc");
    }

    #[test]
    fn rewrite_from_parenthesized_multi_line() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        std::fs::create_dir_all(root.join("pkg")).unwrap();
        fs::write(root.join("pkg/a_pb2.py"), "# a").unwrap();
        fs::write(root.join("pkg/b_pb2.py"), "# b").unwrap();
        let file_dir = root.join("pkg");
        let content = "from pkg import (\n    a_pb2,\n    b_pb2 as bb,\n)\n";
        let (out, changed) = rewrite_lines_in_content(content, &file_dir, root, false).unwrap();
        assert!(changed);
        assert_eq!(out.trim_end(), "from . import a_pb2, b_pb2 as bb");
    }

    #[test]
    fn rewrite_import_list_into_multiple_lines() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        std::fs::create_dir_all(root.join("pkg/sub")).unwrap();
        fs::write(root.join("pkg/a_pb2.py"), "# a").unwrap();
        fs::write(root.join("pkg/sub/b_pb2.py"), "# b").unwrap();
        let file_dir = root; // importing at project root
        let content = "import pkg.a_pb2, pkg.sub.b_pb2 as bb, json\n";
        let (out, changed) = rewrite_lines_in_content(content, file_dir, root, false).unwrap();
        assert!(changed);
        // Should produce two from-import lines and keep 'json' as import
        let lines: Vec<_> = out.lines().collect();
        assert_eq!(lines.len(), 3);
        assert!(
            lines[0].starts_with("from .pkg import a_pb2")
                || lines[1].starts_with("from .pkg import a_pb2")
        );
        assert!(
            lines
                .iter()
                .any(|l| l.starts_with("from .pkg.sub import b_pb2 as bb"))
        );
        assert!(lines.contains(&"import json"));
    }

    #[test]
    fn keep_google_protobuf_in_multi() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        std::fs::create_dir_all(root.join("pkg")).unwrap();
        fs::write(root.join("pkg/a_pb2.py"), "# a").unwrap();
        let file_dir = root.join("pkg");
        let content = "from google.protobuf import timestamp_pb2, duration_pb2\nfrom pkg import a_pb2, timestamp_pb2\n";
        let (out, changed) = rewrite_lines_in_content(content, &file_dir, root, true).unwrap();
        assert!(changed); // a_pb2 should change but google protobuf kept
        assert!(out.contains("from . import a_pb2"));
        assert!(out.contains("from google.protobuf import timestamp_pb2, duration_pb2"));
    }

    #[test]
    fn rewrite_from_sibling_directory() {
        // Test the actual scenario: billing/ importing from order/
        let dir = tempdir().unwrap();
        let root = dir.path();

        // Create structure: <root>/billing/ and <root>/order/
        fs::create_dir_all(root.join("billing")).unwrap();
        fs::create_dir_all(root.join("order")).unwrap();
        fs::write(root.join("order/order_pb2.py"), "# order module\n").unwrap();

        let billing_content = "from order import order_pb2 as order_dot_order__pb2\n";
        fs::write(root.join("billing/billing_pb2.py"), billing_content).unwrap();

        let modified = apply_rewrites_in_tree(root, false, &["_pb2.py".into()], None).unwrap();
        assert_eq!(modified, 1);

        let billing = fs::read_to_string(root.join("billing/billing_pb2.py")).unwrap();
        // Should be sibling import: from ..order import order_pb2 (up one level, then down)
        assert_eq!(
            billing,
            "from ..order import order_pb2 as order_dot_order__pb2\n"
        );
    }
}