// rustify_ml/input.rs

use std::io::{self, Read};
use std::path::{Component, Path};

use anyhow::{Context, Result, anyhow};
use tracing::{info, warn};

use crate::utils::InputSource;

9pub fn load_input(
10    file: Option<&Path>,
11    snippet: bool,
12    git: Option<&str>,
13    git_path: Option<&Path>,
14) -> Result<InputSource> {
15    if snippet {
16        let mut buffer = String::new();
17        io::stdin()
18            .read_to_string(&mut buffer)
19            .context("failed to read Python snippet from stdin")?;
20        info!(chars = buffer.len(), "loaded snippet from stdin");
21        return Ok(InputSource::Snippet(buffer));
22    }
23
24    if let Some(path) = file {
25        let code = std::fs::read_to_string(path)
26            .with_context(|| format!("failed to read Python file at {}", path.display()))?;
27        info!(path = %path.display(), bytes = code.len(), "loaded file input");
28        return Ok(InputSource::File {
29            path: path.to_path_buf(),
30            code,
31        });
32    }
33
34    if let Some(repo) = git {
35        let git_path =
36            git_path.ok_or_else(|| anyhow!("--git-path is required when using --git"))?;
37
38        let tmpdir = tempfile::tempdir().context("failed to create temp dir for git clone")?;
39        let repo_dir = tmpdir.path().join("repo");
40        info!(repo, path = %git_path.display(), "cloning git repo (shallow if supported)");
41        let mut fo = git2::FetchOptions::new();
42        fo.download_tags(git2::AutotagOption::None);
43        fo.update_fetchhead(true);
44        let mut co = git2::build::RepoBuilder::new();
45        co.fetch_options(fo);
46        co.clone(repo, &repo_dir)
47            .with_context(|| format!("failed to clone repo {repo}"))?;
48
49        let target_path = repo_dir.join(git_path);
50        if !target_path.exists() {
51            warn!(path = %target_path.display(), "git path not found in repo");
52            return Err(anyhow!("git path not found: {}", target_path.display()));
53        }
54        let code = std::fs::read_to_string(&target_path).with_context(|| {
55            format!(
56                "failed to read file {} from git repo",
57                target_path.display()
58            )
59        })?;
60        info!(path = %target_path.display(), bytes = code.len(), "loaded git input");
61        return Ok(InputSource::Git {
62            repo: repo.to_string(),
63            path: target_path,
64            code,
65        });
66    }
67
68    Err(anyhow::anyhow!(
69        "no input provided; pass --file, --snippet, or --git"
70    ))
71}