use super::{Formatter, HeaderMetadata, Import, Section};
use std::collections::BTreeMap;
/// Formatter that emits Python source text: a file header, a grouped import
/// block, docstring literals, and the final merged file.
///
/// The type carries no state, so [`PythonFormatter::new`] and
/// [`Default::default`] produce the same value.
#[derive(Clone, Debug)]
pub struct PythonFormatter;

impl PythonFormatter {
    /// Returns a new `PythonFormatter`.
    ///
    /// Being a `const fn`, it can also be used to initialise constants and statics.
    #[must_use]
    pub const fn new() -> Self {
        PythonFormatter
    }
}

impl Default for PythonFormatter {
    /// Equivalent to [`PythonFormatter::new`].
    fn default() -> Self {
        PythonFormatter::new()
    }
}
impl Formatter for PythonFormatter {
    /// Builds the preamble of a generated Python module.
    ///
    /// Always emits a `python3` shebang and a `ruff: noqa` pragma silencing
    /// `EXE001` (shebang on a non-executable file) and `I001` (import block
    /// ordering), and ends with the module docstring. When
    /// `metadata.auto_generated` is set, a "DO NOT EDIT" banner is inserted
    /// before the docstring. It is followed by the schema file and generator
    /// version when those are present.
    fn format_header(&self, metadata: &HeaderMetadata) -> String {
        let mut header = String::new();
        header.push_str("#!/usr/bin/env python3\n");
        header.push_str("# ruff: noqa: EXE001, I001\n");
        if metadata.auto_generated {
            header.push_str("# DO NOT EDIT - Auto-generated by Spikard CLI\n");
            if let Some(schema_file) = &metadata.schema_file {
                header.push_str(&format!("# Schema: {schema_file}\n"));
            }
            if let Some(version) = &metadata.generator_version {
                header.push_str(&format!("# Generator: Spikard {version}\n"));
            }
        }
        header.push_str("\n\"\"\"GraphQL types generated from schema.\"\"\"\n");
        header
    }

    /// Renders `imports` as Python import statements grouped in PEP 8 order.
    ///
    /// Groups are emitted in this order, separated by one blank line:
    /// 1. `__future__` imports, in input order;
    /// 2. standard-library imports (top-level package listed in `STDLIB_MODULES`);
    /// 3. third-party imports (everything not matched by another group);
    /// 4. local, relative imports (module path starting with `.`).
    ///
    /// Within groups 2-4, statements are ordered by full module path. Imports
    /// that share a module keep their input order and are not merged. Empty
    /// groups are skipped. The result is empty for empty input and never ends
    /// with whitespace.
    fn format_imports(&self, imports: &[Import]) -> String {
        // Top-level names of Python standard-library modules. Module names are
        // case-sensitive (e.g. `cProfile`). Some entries were removed in later
        // Python versions but are kept so older code is still grouped correctly.
        // NOTE(review): `typing_extensions` is a PyPI package, not stdlib; it is
        // listed here and therefore grouped with the standard library. Confirm
        // this is intended.
        const STDLIB_MODULES: &[&str] = &[
            "abc", "argparse", "array", "ast", "asynchat", "asyncio", "asyncore", "atexit",
            "audioop", "base64", "bdb", "binascii", "bisect", "builtins", "bz2", "cProfile",
            "calendar", "cgi", "cgitb", "chunk", "cmath", "cmd", "code", "codecs", "codeop",
            "collections", "colorsys", "compileall", "concurrent", "configparser", "contextlib",
            "contextvars", "copy", "copyreg", "crypt", "csv", "ctypes", "curses", "dataclasses",
            "datetime", "dbm", "decimal", "difflib", "dis", "doctest", "email", "encodings",
            "ensurepip", "enum", "errno", "faulthandler", "fcntl", "filecmp", "fileinput",
            "fnmatch", "fractions", "ftplib", "functools", "gc", "getopt", "getpass", "gettext",
            "glob", "graphlib", "grp", "gzip", "hashlib", "heapq", "hmac", "html", "http",
            "idlelib", "imaplib", "imghdr", "imp", "importlib", "inspect", "io", "ipaddress",
            "itertools", "json", "keyword", "lib2to3", "linecache", "locale", "logging", "lzma",
            "mailbox", "mailcap", "marshal", "math", "mimetypes", "mmap", "modulefinder",
            "msilib", "msvcrt", "multiprocessing", "netrc", "nis", "nntplib", "numbers",
            "operator", "optparse", "os", "ossaudiodev", "parser", "pathlib", "pdb", "pickle",
            "pickletools", "pipes", "pkgutil", "platform", "plistlib", "poplib", "posix",
            "posixpath", "pprint", "profile", "pstats", "pty", "pwd", "py_compile", "pyclbr",
            "pydoc", "queue", "quopri", "random", "re", "readline", "reprlib", "resource",
            "rlcompleter", "runpy", "sched", "secrets", "select", "selectors", "shelve", "shlex",
            "shutil", "signal", "site", "smtpd", "smtplib", "sndhdr", "socket", "socketserver",
            "spwd", "sqlite3", "ssl", "stat", "statistics", "string", "stringprep", "struct",
            "subprocess", "sunau", "symbol", "symtable", "sys", "sysconfig", "syslog",
            "tabnanny", "tarfile", "telnetlib", "tempfile", "termios", "test", "textwrap",
            "threading", "time", "timeit", "tkinter", "token", "tokenize", "tomllib", "trace",
            "traceback", "tracemalloc", "tty", "turtle", "types", "typing", "typing_extensions",
            "unicodedata", "unittest", "urllib", "uu", "uuid", "venv", "warnings", "wave",
            "weakref", "webbrowser", "winreg", "winsound", "wsgiref", "xdrlib", "xml", "xmlrpc",
            "zipapp", "zipfile", "zipimport", "zlib", "zoneinfo",
        ];

        if imports.is_empty() {
            return String::new();
        }

        let mut future_imports: Vec<&Import> = Vec::new();
        let mut stdlib_imports: BTreeMap<&str, Vec<&Import>> = BTreeMap::new();
        let mut third_party_imports: BTreeMap<&str, Vec<&Import>> = BTreeMap::new();
        let mut local_imports: BTreeMap<&str, Vec<&Import>> = BTreeMap::new();

        for import in imports {
            let module: &str = &import.module;
            let top_level = module.split('.').next().unwrap_or(module);
            if top_level == "__future__" {
                future_imports.push(import);
                continue;
            }
            // Relative imports must be detected on the full path: for `.models`
            // the first `.`-separated segment is empty, not a dotted name.
            let group = if module.starts_with('.') {
                &mut local_imports
            } else if STDLIB_MODULES.contains(&top_level) {
                &mut stdlib_imports
            } else {
                &mut third_party_imports
            };
            group.entry(module).or_default().push(import);
        }

        let groups = [
            future_imports,
            stdlib_imports.into_values().flatten().collect::<Vec<_>>(),
            third_party_imports.into_values().flatten().collect(),
            local_imports.into_values().flatten().collect(),
        ];

        let mut output = String::new();
        for group in groups {
            if group.is_empty() {
                continue;
            }
            for import in group {
                output.push_str(&format_python_import(import));
                output.push('\n');
            }
            // Blank line between groups; the trailing one is trimmed below.
            output.push('\n');
        }
        output.truncate(output.trim_end().len());
        output
    }

    /// Wraps `content` in a Python triple-quoted string literal.
    ///
    /// The literal is escaped so that Python reads back exactly `content`:
    /// - backslashes are doubled, so text such as `\n`, `\x` or `\N` is neither
    ///   reinterpreted as an escape sequence nor rejected as an invalid one;
    /// - each `"""` run is written as `\"\"\"` so it cannot close the literal;
    /// - quotes at the very end of `content` are escaped one by one, because an
    ///   unescaped trailing `"` would merge with the closing `"""`.
    fn format_docstring(&self, content: &str) -> String {
        let escaped = content.replace('\\', r"\\");
        // Split off the trailing run of quotes so it can be escaped separately.
        let body = escaped.trim_end_matches('"');
        let trailing_quotes = escaped.len() - body.len();

        let mut literal = String::with_capacity(escaped.len() + 2 * trailing_quotes + 12);
        literal.push_str("\"\"\"");
        literal.push_str(&body.replace("\"\"\"", r#"\"\"\""#));
        literal.push_str(&r#"\""#.repeat(trailing_quotes));
        literal.push_str("\"\"\"");
        literal
    }

    /// Assembles a complete Python source file from generated sections.
    ///
    /// For each section kind, only the first section with non-empty content is
    /// used. The parts are always written in the order header, imports, body,
    /// whatever order `sections` lists them in. Each part has its trailing
    /// whitespace trimmed, parts are separated by one blank line, and the
    /// result ends with exactly one `\n`.
    ///
    /// NOTE(review): later non-empty sections of the same kind (e.g. a second
    /// `Section::Body`) are silently dropped; confirm callers never rely on
    /// several bodies being concatenated.
    fn merge_sections(&self, sections: &[Section]) -> String {
        let mut header = String::new();
        let mut imports = String::new();
        let mut body = String::new();

        for section in sections {
            let (slot, content) = match section {
                Section::Header(content) => (&mut header, content),
                Section::Imports(content) => (&mut imports, content),
                Section::Body(content) => (&mut body, content),
            };
            if slot.is_empty() {
                slot.clone_from(content);
            }
        }

        let mut output = String::new();
        for part in [&header, &imports, &body] {
            if !part.is_empty() {
                output.push_str(part.trim_end());
                output.push_str("\n\n");
            }
        }
        output.truncate(output.trim_end().len());
        output.push('\n');
        output
    }
}
/// Renders a single [`Import`] as one Python import statement.
///
/// An import without items becomes `import <module>`. Otherwise the items are
/// joined with `", "` into `from <module> import <items>`.
fn format_python_import(import: &Import) -> String {
    let module = &import.module;
    if !import.items.is_empty() {
        let names = import.items.join(", ");
        return format!("from {module} import {names}");
    }
    format!("import {module}")
}
/// Unit tests for `PythonFormatter`: header contents, import grouping and
/// rendering, docstring quoting, and section merging.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_header_with_metadata() {
        let formatter = PythonFormatter::new();
        let metadata = HeaderMetadata {
            auto_generated: true,
            schema_file: Some("schema.graphql".to_string()),
            generator_version: Some("0.6.2".to_string()),
        };
        let header = formatter.format_header(&metadata);
        assert!(header.contains("#!/usr/bin/env python3"));
        assert!(header.contains("# ruff: noqa: EXE001, I001"));
        assert!(header.contains("# DO NOT EDIT - Auto-generated by Spikard CLI"));
        assert!(header.contains("# Schema: schema.graphql"));
        assert!(header.contains("# Generator: Spikard 0.6.2"));
        assert!(header.contains("\"\"\"GraphQL types generated from schema.\"\"\""));
    }

    #[test]
    fn test_format_header_without_metadata() {
        let formatter = PythonFormatter::new();
        let metadata = HeaderMetadata {
            auto_generated: false,
            schema_file: None,
            generator_version: None,
        };
        let header = formatter.format_header(&metadata);
        assert!(header.contains("#!/usr/bin/env python3"));
        assert!(!header.contains("# DO NOT EDIT"));
    }

    #[test]
    fn test_format_imports_empty() {
        let formatter = PythonFormatter::new();
        // Explicit element type: an untyped `[]` relies on inference through
        // the unsizing coercion at the call site.
        let imports: [Import; 0] = [];
        let output = formatter.format_imports(&imports);
        assert!(output.is_empty());
    }

    #[test]
    fn test_format_imports_grouped_and_sorted() {
        let formatter = PythonFormatter::new();
        let imports = vec![
            Import::with_items("typing", vec!["List", "Dict"]),
            Import::with_items("__future__", vec!["annotations"]),
            Import::new("msgspec"),
            Import::new("graphql"),
        ];
        let output = formatter.format_imports(&imports);
        let lines: Vec<&str> = output.lines().collect();
        assert_eq!(lines[0], "from __future__ import annotations");
        assert!(lines.contains(&"from typing import List, Dict"));
        assert!(lines.contains(&"import graphql"));
        assert!(lines.contains(&"import msgspec"));
    }

    #[test]
    fn test_format_imports_simple_module() {
        let formatter = PythonFormatter::new();
        let imports = vec![Import::new("asyncio")];
        let output = formatter.format_imports(&imports);
        assert_eq!(output.trim(), "import asyncio");
    }

    #[test]
    fn test_format_imports_with_items() {
        let formatter = PythonFormatter::new();
        let imports = vec![Import::with_items("typing", vec!["Optional", "Union"])];
        let output = formatter.format_imports(&imports);
        assert_eq!(output.trim(), "from typing import Optional, Union");
    }

    #[test]
    fn test_format_docstring() {
        let formatter = PythonFormatter::new();
        let content = "This is a test docstring";
        let output = formatter.format_docstring(content);
        assert_eq!(output, "\"\"\"This is a test docstring\"\"\"");
    }

    #[test]
    fn test_format_docstring_with_quotes() {
        let formatter = PythonFormatter::new();
        let content = r#"This says "hello""""#;
        let output = formatter.format_docstring(content);
        assert!(output.contains(r#"\"\"\""#));
    }

    #[test]
    fn test_merge_sections_in_order() {
        let formatter = PythonFormatter::new();
        let sections = vec![
            Section::Header("#!/usr/bin/env python3\n# Auto-gen".to_string()),
            Section::Imports("from typing import List".to_string()),
            Section::Body("class MyType:\n    pass".to_string()),
        ];
        let output = formatter.merge_sections(&sections);
        let lines: Vec<&str> = output.lines().collect();
        assert!(lines[0].contains("#!/usr/bin/env python3"));
        assert!(lines.iter().any(|l| l.contains("from typing import List")));
        assert!(lines.iter().any(|l| l.contains("class MyType")));
    }

    #[test]
    fn test_merge_sections_duplicate_headers() {
        let formatter = PythonFormatter::new();
        let sections = vec![
            Section::Header("#!/usr/bin/env python3".to_string()),
            Section::Header("#!/usr/bin/env python3".to_string()),
            Section::Body("class MyType:\n    pass".to_string()),
        ];
        let output = formatter.merge_sections(&sections);
        let header_count = output.matches("#!/usr/bin/env python3").count();
        assert_eq!(header_count, 1, "Should not duplicate headers");
    }

    #[test]
    fn test_merge_sections_trailing_newline() {
        let formatter = PythonFormatter::new();
        let sections = vec![Section::Body("class MyType:\n    pass".to_string())];
        let output = formatter.merge_sections(&sections);
        assert!(output.ends_with('\n'));
        assert!(!output.ends_with("\n\n"));
    }
}