Skip to main content

fraiseql_cli/commands/extract/
mod.rs

1//! `fraiseql extract` — Parse annotated source files to schema.json
2//!
3//! Extracts FraiseQL type and query definitions from annotated source files
4//! in any of the 9 supported authoring languages. Pure text processing,
5//! no language runtime needed.
6
7use std::{
8    collections::HashMap,
9    fs,
10    path::{Path, PathBuf},
11};
12
13use anyhow::{Context, Result};
14use regex::Regex;
15use tracing::info;
16
17use super::init::Language;
18use crate::schema::intermediate::{IntermediateQuery, IntermediateSchema, IntermediateType};
19
20mod csharp;
21mod go;
22mod java;
23mod kotlin;
24mod python;
25mod rust;
26mod scala;
27mod swift;
28#[cfg(test)]
29mod tests;
30mod typescript;
31
32use self::{
33    csharp::CSharpExtractor, go::GoExtractor, java::JavaExtractor, kotlin::KotlinExtractor,
34    python::PythonExtractor, rust::RustExtractor, scala::ScalaExtractor, swift::SwiftExtractor,
35    typescript::TypeScriptExtractor,
36};
37
38// =============================================================================
39// Core types
40// =============================================================================
41
/// Extracted schema from a single source file.
///
/// Returned by a [`SchemaExtractor`]; [`run`] merges one of these per input file
/// into the final [`IntermediateSchema`], keeping the first definition of each name.
struct ExtractedSchema {
    /// Type definitions found in the file.
    types:   Vec<IntermediateType>,
    /// Query definitions found in the file.
    queries: Vec<IntermediateQuery>,
}
47
/// Trait for language-specific schema extraction.
///
/// `dispatch_extractor` maps each supported [`Language`] to its implementor.
trait SchemaExtractor {
    /// Extract type and query definitions from the full text of one source file.
    ///
    /// # Errors
    ///
    /// Implementations are expected to return an error when the source cannot be
    /// turned into a schema (e.g. malformed annotations).
    fn extract(&self, source: &str) -> Result<ExtractedSchema>;
}
52
53// =============================================================================
54// Public API
55// =============================================================================
56
57/// Run the extract command.
58///
59/// # Errors
60///
61/// Returns an error if no source files are found, file I/O fails, or schema extraction
62/// encounters invalid syntax in the input files.
63pub fn run(
64    inputs: &[String],
65    language_override: Option<&str>,
66    recursive: bool,
67    output: &str,
68) -> Result<()> {
69    let override_lang = language_override
70        .map(|s| s.parse::<Language>().map_err(|e| anyhow::anyhow!(e)))
71        .transpose()?;
72
73    let mut all_types: Vec<IntermediateType> = Vec::new();
74    let mut all_queries: Vec<IntermediateQuery> = Vec::new();
75
76    let files = collect_files(inputs, recursive)?;
77
78    if files.is_empty() {
79        anyhow::bail!("No source files found in the provided input paths");
80    }
81
82    for file in &files {
83        let lang = match override_lang {
84            Some(l) => l,
85            None => detect_language(file)?,
86        };
87
88        let source = fs::read_to_string(file)
89            .with_context(|| format!("Failed to read {}", file.display()))?;
90
91        let extracted = dispatch_extractor(lang, &source)
92            .with_context(|| format!("Failed to extract from {}", file.display()))?;
93
94        for t in extracted.types {
95            if !all_types.iter().any(|existing| existing.name == t.name) {
96                all_types.push(t);
97            }
98        }
99        for q in extracted.queries {
100            if !all_queries.iter().any(|existing| existing.name == q.name) {
101                all_queries.push(q);
102            }
103        }
104    }
105
106    let schema = IntermediateSchema {
107        version: "2.0.0".to_string(),
108        types: all_types,
109        queries: all_queries,
110        ..IntermediateSchema::default()
111    };
112
113    let json = serde_json::to_string_pretty(&schema).context("Failed to serialize schema")?;
114    fs::write(output, &json).with_context(|| format!("Failed to write {output}"))?;
115
116    info!("Extracted {} types and {} queries", schema.types.len(), schema.queries.len());
117    println!(
118        "Extracted {} types, {} queries → {}",
119        schema.types.len(),
120        schema.queries.len(),
121        output,
122    );
123
124    Ok(())
125}
126
127// =============================================================================
128// File collection
129// =============================================================================
130
131fn collect_files(inputs: &[String], recursive: bool) -> Result<Vec<PathBuf>> {
132    let mut files = Vec::new();
133    for input in inputs {
134        let path = PathBuf::from(input);
135        if path.is_file() {
136            files.push(path);
137        } else if path.is_dir() {
138            if recursive {
139                collect_dir_recursive(&path, &mut files)?;
140            } else {
141                collect_dir_flat(&path, &mut files)?;
142            }
143        } else {
144            anyhow::bail!("Path does not exist: {input}");
145        }
146    }
147    Ok(files)
148}
149
150fn collect_dir_recursive(dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
151    for entry in walkdir::WalkDir::new(dir)
152        .follow_links(true)
153        .into_iter()
154        .filter_map(std::result::Result::ok)
155    {
156        let path = entry.path();
157        if path.is_file() && is_known_extension(path) {
158            files.push(path.to_path_buf());
159        }
160    }
161    Ok(())
162}
163
164fn collect_dir_flat(dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
165    for entry in fs::read_dir(dir).context("Failed to read directory")? {
166        let entry = entry?;
167        let path = entry.path();
168        if path.is_file() && is_known_extension(&path) {
169            files.push(path);
170        }
171    }
172    Ok(())
173}
174
175fn is_known_extension(path: &Path) -> bool {
176    path.extension()
177        .and_then(|e| e.to_str())
178        .and_then(Language::from_extension)
179        .is_some()
180}
181
182fn detect_language(path: &Path) -> Result<Language> {
183    let ext = path
184        .extension()
185        .and_then(|e| e.to_str())
186        .ok_or_else(|| anyhow::anyhow!("File has no extension: {}", path.display()))?;
187    Language::from_extension(ext)
188        .ok_or_else(|| anyhow::anyhow!("Unsupported file extension: .{ext}"))
189}
190
191fn dispatch_extractor(lang: Language, source: &str) -> Result<ExtractedSchema> {
192    match lang {
193        Language::Python => PythonExtractor.extract(source),
194        Language::TypeScript => TypeScriptExtractor.extract(source),
195        Language::Rust => RustExtractor.extract(source),
196        Language::Java => JavaExtractor.extract(source),
197        Language::Kotlin => KotlinExtractor.extract(source),
198        Language::Go => GoExtractor.extract(source),
199        Language::CSharp => CSharpExtractor.extract(source),
200        Language::Swift => SwiftExtractor.extract(source),
201        Language::Scala => ScalaExtractor.extract(source),
202        Language::Php => anyhow::bail!(
203            "PHP extraction is handled by the PHP SDK binary (`vendor/bin/fraiseql export`). Run that first to produce schema.json, then use `fraiseql compile`."
204        ),
205    }
206}
207
208// =============================================================================
209// Shared utilities
210// =============================================================================
211
212/// Parse annotation parameters from a string like `key = "value", key2 = true`.
213fn parse_annotation_params(s: &str) -> HashMap<String, String> {
214    let mut params = HashMap::new();
215    // Match key = "value", key: "value", key = true, key = false, key = ClassName
216    // Also matches typeof(X) for C# and classOf[X] for Scala
217    let re = Regex::new(
218        r#"(\w+)\s*[=:]\s*(?:"([^"]*)"|'([^']*)'|(true|false)|(\w[\w.<>\[\]:]*(?:::class|\.class|\.self)?(?:\([^)]*\))?))"#,
219    )
220    .expect("valid regex");
221
222    for cap in re.captures_iter(s) {
223        let key = cap[1].to_string();
224        let value = if let Some(m) = cap.get(2) {
225            m.as_str().to_string()
226        } else if let Some(m) = cap.get(3) {
227            m.as_str().to_string()
228        } else if let Some(m) = cap.get(4) {
229            m.as_str().to_string()
230        } else if let Some(m) = cap.get(5) {
231            strip_class_ref(m.as_str())
232        } else {
233            continue;
234        };
235        params.insert(key, value);
236    }
237    params
238}
239
/// Remove language-specific class-reference syntax and return the bare type name.
///
/// Handles `Post.class`, `Post::class`, `Post.self`, `classOf[Post]` and
/// `typeof(Post)`. Trailing `.class` / `.self` / `::class` suffixes are removed
/// first (repeatedly, in that order), then a `classOf[...]` or `typeof(...)`
/// wrapper is unwrapped. Any other input is returned unchanged.
fn strip_class_ref(s: &str) -> String {
    let mut name = s;
    for suffix in [".class", ".self", "::class"] {
        name = name.trim_end_matches(suffix);
    }

    let unwrapped = name
        .strip_prefix("classOf[")
        .and_then(|rest| rest.strip_suffix(']'))
        .or_else(|| name.strip_prefix("typeof(").and_then(|rest| rest.strip_suffix(')')));

    unwrapped.unwrap_or(name).to_owned()
}
260
/// Convert `camelCase` or `PascalCase` to `snake_case`.
///
/// An underscore is inserted before an uppercase letter when it follows a lowercase
/// letter or a digit (`postTitle` → `post_title`, `v2Name` → `v2_name`), or when it
/// ends a run of capitals and starts a new word (`HTTPServer` → `http_server`).
/// Runs of capitals therefore stay together as one word (`PostID` → `post_id`),
/// and no extra underscore is added after an existing separator
/// (`Foo_Bar` → `foo_bar`). Uppercase letters are lowercased with full Unicode
/// case mapping.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);
    for (i, &ch) in chars.iter().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                let prev = chars[i - 1];
                let next_is_lower = matches!(chars.get(i + 1), Some(c) if c.is_lowercase());
                let starts_word = prev.is_lowercase()
                    || prev.is_numeric()
                    || (prev.is_uppercase() && next_is_lower);
                if starts_word {
                    result.push('_');
                }
            }
            // `to_lowercase` can yield more than one char (e.g. 'İ'); keep them all.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
276
277/// Map a language-specific type string to (GraphQL type, nullable).
278fn map_type(lang: Language, type_str: &str) -> (String, bool) {
279    // Handle nullable wrappers first
280    let (inner, nullable) = extract_nullable(lang, type_str);
281    let graphql = map_primitive_type(&inner);
282    (graphql, nullable)
283}
284
285fn extract_nullable(lang: Language, type_str: &str) -> (String, bool) {
286    let trimmed = type_str.trim();
287
288    match lang {
289        Language::Python => {
290            // `str | None` or `int | None`
291            if let Some(base) =
292                trimmed.strip_suffix("| None").or_else(|| trimmed.strip_suffix("|None"))
293            {
294                return (base.trim().to_string(), true);
295            }
296            // `Optional[str]`
297            if let Some(inner) = trimmed.strip_prefix("Optional[").and_then(|s| s.strip_suffix(']'))
298            {
299                return (inner.trim().to_string(), true);
300            }
301        },
302        Language::Rust => {
303            if let Some(inner) = trimmed.strip_prefix("Option<").and_then(|s| s.strip_suffix('>')) {
304                return (inner.trim().to_string(), true);
305            }
306        },
307        Language::Kotlin | Language::Swift | Language::CSharp => {
308            if let Some(base) = trimmed.strip_suffix('?') {
309                return (base.to_string(), true);
310            }
311        },
312        Language::Go => {
313            if let Some(base) = trimmed.strip_prefix('*') {
314                return (base.to_string(), true);
315            }
316        },
317        Language::Scala => {
318            if let Some(inner) = trimmed.strip_prefix("Option[").and_then(|s| s.strip_suffix(']')) {
319                return (inner.trim().to_string(), true);
320            }
321        },
322        Language::Java => {
323            // Nullable is handled via @Nullable annotation, not type syntax
324        },
325        Language::TypeScript => {
326            // TypeScript uses explicit `nullable: true` in the object literal
327        },
328        Language::Php => {
329            // PHP uses ?Type prefix for nullable
330            if let Some(base) = trimmed.strip_prefix('?') {
331                return (base.to_string(), true);
332            }
333        },
334    }
335
336    (trimmed.to_string(), false)
337}
338
339/// Derive query name from interface/class name.
340/// Posts → posts, PostById → post, Authors → authors, AuthorById → author, Tags → tags
341fn derive_query_name(interface_name: &str) -> String {
342    // "ById" suffix → singular, without ById
343    if let Some(base) = interface_name.strip_suffix("ById") {
344        return to_snake_case(base).to_lowercase();
345    }
346    // Otherwise just lowercase the whole thing
347    to_snake_case(interface_name).to_lowercase()
348}
349
/// Map a bare (already non-nullable) type name to a GraphQL scalar name.
///
/// Recognises the primitive spellings of the supported authoring languages, for
/// example Python `int`, Rust `i64`, Go `float64`, Java `Integer`, Swift `Int32`.
/// Any other name is assumed to be a user-defined type and is returned unchanged.
fn map_primitive_type(s: &str) -> String {
    let scalar = match s {
        // Integer types
        "int" | "i32" | "i64" | "Int" | "Integer" | "long" | "Long" | "int32" | "int64"
        // Smaller / unsigned widths (Rust, Go), `short` (Java/Kotlin/C#),
        // and sized Swift / C# integer names.
        | "i8" | "i16" | "u8" | "u16" | "u32" | "int8" | "int16" | "uint8" | "uint16"
        | "uint32" | "short" | "Short" | "Int16" | "Int32" | "Int64" => "Int",
        // Float types (Go spells them in lowercase: `float32`, `float64`)
        "float" | "f32" | "f64" | "Float" | "Double" | "double" | "decimal" | "Decimal"
        | "Float32" | "Float64" | "float32" | "float64" => "Float",
        // Boolean types
        "bool" | "boolean" | "Boolean" | "Bool" | "BIT" => "Boolean",
        // String types
        "str" | "String" | "string" | "&str" | "NVARCHAR" => "String",
        // ID type
        "ID" => "ID",
        // DateTime
        "DateTime" | "Instant" | "LocalDateTime" | "ZonedDateTime" | "Date" => "DateTime",
        // Unknown → assume it's a custom type name and pass through
        other => other,
    };
    scalar.to_string()
}