Skip to main content

yallm_macros/
lib.rs

1//! Procedural macros for yallm, including OpenAPI type generation.
2//!
3//! # Example
4//!
5//! ```ignore
6//! yallm_macros::include_openapi! {
7//!     url = "https://example.com/openapi.yml",
8//!     root_types = ["CreateChatCompletionRequest"],
9//! }
10//! ```
11
12use std::collections::HashSet;
13
14use proc_macro::TokenStream;
15use quote::quote;
16use regex::Regex;
17use serde_json::Value;
18use syn::parse::{Parse, ParseStream};
19use syn::{Ident, LitStr, Token, braced, bracketed};
20
/// Parsed arguments of the `include_openapi!` macro invocation.
///
/// At least one of `url` / `local_file` is set once parsing succeeds; the
/// parser rejects input that provides neither.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct OpenApiConfig {
    /// Location the OpenAPI document is downloaded from (`url = "..."`).
    url: Option<String>,
    /// Path of a local copy of the document, joined onto `CARGO_MANIFEST_DIR`
    /// when the types are generated (`local_file = "..."`).
    local_file: Option<String>,
    /// Schema names whose definitions (and everything they reference) are kept.
    root_types: Vec<String>,
    /// JSON object text mapping schema names to extra schema definitions.
    extra_definitions: Option<String>,
    /// Optional path, joined onto `CARGO_MANIFEST_DIR`, where the final JSON
    /// schema is written for inspection.
    debug_schema_path: Option<String>,
}
28
29impl Parse for OpenApiConfig {
30    fn parse(input: ParseStream) -> syn::Result<Self> {
31        let mut url = None;
32        let mut local_file = None;
33        let mut root_types = Vec::new();
34        let mut extra_definitions = None;
35        let mut debug_schema_path = None;
36
37        while !input.is_empty() {
38            let key: Ident = input.parse()?;
39            input.parse::<Token![=]>()?;
40
41            match key.to_string().as_str() {
42                "url" => {
43                    let lit: LitStr = input.parse()?;
44                    url = Some(lit.value());
45                }
46                "local_file" => {
47                    let lit: LitStr = input.parse()?;
48                    local_file = Some(lit.value());
49                }
50                "root_types" => {
51                    let content;
52                    bracketed!(content in input);
53                    while !content.is_empty() {
54                        let lit: LitStr = content.parse()?;
55                        root_types.push(lit.value());
56                        if content.peek(Token![,]) {
57                            content.parse::<Token![,]>()?;
58                        }
59                    }
60                }
61                "extra_definitions" => {
62                    if input.peek(LitStr) {
63                        // 兼容旧的字符串模式
64                        let lit: LitStr = input.parse()?;
65                        extra_definitions = Some(lit.value());
66                    } else if input.peek(syn::token::Brace) {
67                        // 新的直接 JSON 模式
68                        let content;
69                        braced!(content in input);
70                        let tokens: proc_macro2::TokenStream = content.parse()?;
71                        let json_str = format!("{{{}}}", tokens);
72                        // 验证是有效的 JSON
73                        let _: serde_json::Value =
74                            serde_json::from_str(&json_str).map_err(|e| {
75                                syn::Error::new(key.span(), format!("invalid JSON: {}", e))
76                            })?;
77                        extra_definitions = Some(json_str);
78                    } else {
79                        return Err(syn::Error::new(
80                            input.span(),
81                            "expected string literal or JSON object",
82                        ));
83                    }
84                }
85                "debug_schema_path" => {
86                    let lit: LitStr = input.parse()?;
87                    debug_schema_path = Some(lit.value());
88                }
89                _ => {
90                    return Err(syn::Error::new(key.span(), format!("unknown key: {}", key)));
91                }
92            }
93
94            if input.peek(Token![,]) {
95                input.parse::<Token![,]>()?;
96            }
97        }
98
99        // URL is required unless local_file is provided
100        if url.is_none() && local_file.is_none() {
101            return Err(syn::Error::new(
102                input.span(),
103                "missing `url` or `local_file`",
104            ));
105        }
106
107        Ok(OpenApiConfig {
108            url,
109            local_file,
110            root_types,
111            extra_definitions,
112            debug_schema_path,
113        })
114    }
115}
116
117/// Include OpenAPI-generated types inline.
118///
119/// This macro fetches the OpenAPI spec at compile time (with HTTP caching),
120/// generates Rust types, and includes them directly in the source.
121///
122/// # Example
123///
124/// ```ignore
125/// include_openapi! {
126///     url = "https://example.com/openapi.yml",
127///     root_types = ["ChatCompletionRequestMessage"],
128///     extra_definitions = {
129///         "MissingType": {
130///             "type": "object",
131///             "properties": { "name": { "type": "string" } }
132///         }
133///     },
134/// }
135/// ```
136#[proc_macro]
137pub fn include_openapi(input: TokenStream) -> TokenStream {
138    let config = syn::parse_macro_input!(input as OpenApiConfig);
139
140    let code = match generate_types(&config) {
141        Ok(code) => code,
142        Err(e) => {
143            return syn::Error::new(proc_macro2::Span::call_site(), e.to_string())
144                .to_compile_error()
145                .into();
146        }
147    };
148
149    let tokens: proc_macro2::TokenStream = match code.parse() {
150        Ok(t) => t,
151        Err(e) => {
152            return syn::Error::new(
153                proc_macro2::Span::call_site(),
154                format!("Failed to parse generated code: {}", e),
155            )
156            .to_compile_error()
157            .into();
158        }
159    };
160
161    quote! { #tokens }.into()
162}
163
164fn generate_types(config: &OpenApiConfig) -> Result<String, Box<dyn std::error::Error>> {
165    // Try local file first, then URL with caching
166    let spec_yaml = if let Some(ref local) = config.local_file {
167        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")?;
168        let local_path = std::path::Path::new(&manifest_dir).join(local);
169        if local_path.exists() {
170            std::fs::read_to_string(&local_path)?
171        } else if let Some(ref url) = config.url {
172            fetch_with_cache(url)?
173        } else {
174            return Err(format!("Local file not found: {}", local_path.display()).into());
175        }
176    } else if let Some(ref url) = config.url {
177        fetch_with_cache(url)?
178    } else {
179        return Err("No URL or local file specified".into());
180    };
181
182    let spec_yaml = preprocess_yaml(&spec_yaml);
183    let spec: Value = serde_yaml_ng::from_str(&spec_yaml)?;
184
185    let mut schemas = spec
186        .get("components")
187        .and_then(|c| c.get("schemas"))
188        .ok_or("No components/schemas in OpenAPI spec")?
189        .clone();
190
191    convert_openapi_to_json_schema(&mut schemas);
192
193    // Extract inline type enums to avoid typify merging them
194    extract_inline_type_enums(&mut schemas);
195
196    // Add extra definitions if provided
197    if let Some(ref extra) = config.extra_definitions {
198        let extra_defs: serde_json::Map<String, Value> = serde_json::from_str(extra)?;
199        if let Value::Object(ref mut map) = schemas {
200            for (k, v) in extra_defs {
201                map.insert(k, v);
202            }
203        }
204    }
205
206    let root_refs: Vec<&str> = config.root_types.iter().map(|s| s.as_str()).collect();
207    let schemas = filter_schemas(schemas, &root_refs);
208
209    let mut json_schema = serde_json::json!({
210        "$schema": "http://json-schema.org/draft-07/schema#",
211        "definitions": schemas,
212    });
213    convert_openapi_to_json_schema(&mut json_schema);
214
215    // Write debug schema if requested
216    if let Some(ref debug_path) = config.debug_schema_path {
217        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")?;
218        let debug_file = std::path::Path::new(&manifest_dir).join(debug_path);
219        let formatted = serde_json::to_string_pretty(&json_schema)?;
220        std::fs::write(&debug_file, formatted)?;
221    }
222
223    let mut type_space = typify::TypeSpace::new(
224        typify::TypeSpaceSettings::default().with_derive("PartialEq".to_string()),
225    );
226
227    let root_schema: schemars::schema::RootSchema = serde_json::from_value(json_schema.clone())
228        .map_err(|e| format!("Failed to parse JSON schema: {}", e,))?;
229    type_space
230        .add_root_schema(root_schema)
231        .map_err(|e| format!("Failed to add root schema to type space: {}", e))?;
232
233    Ok(type_space.to_stream().to_string())
234}
235
236fn fetch_with_cache(url: &str) -> Result<String, Box<dyn std::error::Error>> {
237    use http_cache_reqwest::{CACacheManager, Cache, CacheMode, HttpCache, HttpCacheOptions};
238    use reqwest_middleware::ClientBuilder;
239
240    let cache_dir = resolve_cache_dir()?;
241
242    let rt = tokio::runtime::Runtime::new()?;
243
244    rt.block_on(async {
245        let client = ClientBuilder::new(reqwest::Client::new())
246            .with(Cache(HttpCache {
247                mode: CacheMode::Default,
248                manager: CACacheManager { path: cache_dir },
249                options: HttpCacheOptions::default(),
250            }))
251            .build();
252
253        let response = client.get(url).send().await?;
254        let text = response.text().await?;
255        Ok(text)
256    })
257}
258
259fn resolve_cache_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
260    let mut candidates: Vec<std::path::PathBuf> = Vec::new();
261
262    if let Ok(dir) = std::env::var("YALLM_CACHE_DIR") {
263        candidates.push(std::path::PathBuf::from(dir));
264    }
265
266    if let Ok(dir) = std::env::var("CARGO_TARGET_DIR") {
267        candidates.push(std::path::PathBuf::from(dir).join("yallm-cache"));
268    }
269
270    if let Some(dir) = dirs::cache_dir() {
271        candidates.push(dir.join("yallm"));
272    }
273
274    candidates.push(std::env::temp_dir().join("yallm-cache"));
275
276    for candidate in candidates {
277        if ensure_writable_dir(&candidate).is_ok() {
278            return Ok(candidate);
279        }
280    }
281
282    Err("Failed to create cache directory for OpenAPI spec".into())
283}
284
/// Checks that `path` is a directory this process can create files in.
///
/// Creates the directory (and any missing parents), then creates, writes and
/// removes a uniquely named probe file inside it.
///
/// # Errors
///
/// Returns the first I/O error hit while creating the directory or while
/// creating, writing or removing the probe file.
fn ensure_writable_dir(path: &std::path::Path) -> std::io::Result<()> {
    use std::io::Write;

    std::fs::create_dir_all(path)?;

    // Process id plus a nanosecond timestamp keeps concurrent probes apart.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let probe_path = path.join(format!(
        ".yallm_cache_write_test_{}_{}",
        std::process::id(),
        nanos
    ));

    // `create_new` makes the open fail instead of clobbering an existing file.
    let mut probe = std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&probe_path)?;

    let written = probe.write_all(b"ok");
    // Close the handle before deleting, and delete even when the write
    // failed so a failed probe does not leave a stray file behind.
    drop(probe);
    let removed = std::fs::remove_file(&probe_path);

    written.and(removed)
}
306
307// ============================================================================
308// OpenAPI to JSON Schema conversion utilities
309// ============================================================================
310
311/// Preprocess YAML to fix problematic values.
312fn preprocess_yaml(yaml: &str) -> String {
313    let re = Regex::new(r"minimum:\s*-\d{15,}").unwrap();
314    let yaml = re.replace_all(yaml, "minimum: -2147483648").to_string();
315
316    let re = Regex::new(r"maximum:\s*\d{15,}").unwrap();
317    re.replace_all(&yaml, "maximum: 2147483647").to_string()
318}
319
/// Rewrites an OpenAPI schema tree in place into JSON Schema form.
///
/// Every JSON object in the tree is rewritten first and its values are
/// visited afterwards; arrays are visited element by element; scalars are
/// left alone. For each object, in this order:
///
/// 1. keys starting with `x-` are dropped;
/// 2. a `$ref` starting with `#/components/schemas/` is retargeted to
///    `#/definitions/`;
/// 3. on `type: object` nodes, properties whose schema has an `anyOf`
///    containing `{"type": "null"}` or has `default: null` are removed from
///    `required`;
/// 4. an `anyOf` made of a null alternative plus exactly one other object is
///    replaced by that object's keys;
/// 5. boolean `exclusiveMinimum` / `exclusiveMaximum` are converted to the
///    numeric form or dropped;
/// 6. `nullable: true` is folded into `type`, into an `anyOf` with the
///    `$ref`, or into `type: null`;
/// 7. `discriminator`, `example`, `examples`, `externalDocs`, `xml` and
///    `nullable` are removed;
/// 8. `const` becomes a single-element `enum`;
/// 9. `default: null` is removed unless `type` admits null;
/// 10. string schemas titled `Id` lose `pattern`, and those titled `Name`
///     lose `enum`, `minLength` and `maxLength`.
///
/// NOTE(review): the recursion visits every object value, including maps
/// keyed by user-chosen names such as the value of `properties`. A property
/// literally named `example`, `const`, `nullable` or `x-...` would therefore
/// be removed or rewritten by the steps above — confirm whether that is
/// acceptable for the specs this is used with.
fn convert_openapi_to_json_schema(value: &mut Value) {
    match value {
        Value::Object(map) => {
            // Remove x- extension fields
            // (keys are collected first because the map cannot be mutated
            // while it is being iterated)
            let keys_to_remove: Vec<String> = map
                .keys()
                .filter(|k| k.starts_with("x-"))
                .cloned()
                .collect();
            for key in keys_to_remove {
                map.remove(&key);
            }

            // Convert $ref paths
            if let Some(Value::String(ref_path)) = map.get_mut("$ref")
                && ref_path.starts_with("#/components/schemas/")
            {
                *ref_path = ref_path.replace("#/components/schemas/", "#/definitions/");
            }

            // For object types with properties, identify nullable properties BEFORE simplification
            // and remove them from required
            // (the child schemas have not been visited yet at this point, so
            // their `anyOf` / `default` entries are still in their input form)
            let nullable_props: HashSet<String> = if map.get("type")
                == Some(&Value::String("object".to_string()))
            {
                if let Some(Value::Object(props)) = map.get("properties") {
                    props
                        .iter()
                        .filter_map(|(name, prop_schema)| {
                            if let Value::Object(prop_obj) = prop_schema {
                                // Check if property has anyOf with null type
                                if let Some(Value::Array(any_of)) = prop_obj.get("anyOf") {
                                    let has_null = any_of.iter().any(|v| {
                                        matches!(v, Value::Object(m) if m.get("type") == Some(&Value::String("null".to_string())))
                                    });
                                    if has_null {
                                        return Some(name.clone());
                                    }
                                }
                                // Check if property has default: null (indicates nullable)
                                if prop_obj.get("default") == Some(&Value::Null) {
                                    return Some(name.clone());
                                }
                            }
                            None
                        })
                        .collect()
                } else {
                    HashSet::new()
                }
            } else {
                HashSet::new()
            };

            // Remove nullable properties from required
            // (non-string entries in `required` are kept as they are)
            if !nullable_props.is_empty()
                && let Some(Value::Array(required)) = map.get_mut("required")
            {
                required.retain(|v| {
                    if let Value::String(s) = v {
                        !nullable_props.contains(s)
                    } else {
                        true
                    }
                });
            }

            // Handle anyOf with null type - simplify to single type
            // `replacement` is a copy of the single non-null alternative, or
            // `None` when the `anyOf` does not have exactly that shape.
            let replacement = if let Some(Value::Array(any_of)) = map.get("anyOf") {
                let non_null: Vec<&Value> = any_of
                    .iter()
                    .filter(|v| {
                        !matches!(v, Value::Object(m) if m.get("type") == Some(&Value::String("null".to_string())))
                    })
                    .collect();

                // At least one alternative was filtered out as `{"type": "null"}`.
                let has_null = any_of.len() != non_null.len();

                if has_null && non_null.len() == 1 {
                    if let Value::Object(inner) = non_null[0] {
                        Some(inner.clone())
                    } else {
                        None
                    }
                } else {
                    None
                }
            } else {
                None
            };

            // Inline the surviving alternative; its keys overwrite same-named
            // keys already present on this node.
            if let Some(inner) = replacement {
                map.remove("anyOf");
                for (k, v) in inner {
                    map.insert(k, v);
                }
            }

            // Handle OpenAPI 3.0 exclusiveMinimum/exclusiveMaximum boolean format
            // `true` moves the value of `minimum` into `exclusiveMinimum` (or
            // drops the flag when there is no `minimum`); `false` drops the flag.
            if let Some(Value::Bool(true)) = map.get("exclusiveMinimum") {
                if let Some(min_val) = map.remove("minimum") {
                    map.insert("exclusiveMinimum".to_string(), min_val);
                } else {
                    map.remove("exclusiveMinimum");
                }
            } else if let Some(Value::Bool(false)) = map.get("exclusiveMinimum") {
                map.remove("exclusiveMinimum");
            }

            // Same treatment for the upper bound.
            if let Some(Value::Bool(true)) = map.get("exclusiveMaximum") {
                if let Some(max_val) = map.remove("maximum") {
                    map.insert("exclusiveMaximum".to_string(), max_val);
                } else {
                    map.remove("exclusiveMaximum");
                }
            } else if let Some(Value::Bool(false)) = map.get("exclusiveMaximum") {
                map.remove("exclusiveMaximum");
            }

            // Handle nullable: true
            // (`remove` deletes the key whatever its value; only `true`
            // triggers the rewrite below)
            if let Some(Value::Bool(true)) = map.remove("nullable") {
                if let Some(type_val) = map.get("type").cloned() {
                    // If type exists, convert to array with null
                    // e.g., "string" -> ["string", "null"]
                    match type_val {
                        Value::String(t) => {
                            map.insert(
                                "type".to_string(),
                                Value::Array(vec![
                                    Value::String(t),
                                    Value::String("null".to_string()),
                                ]),
                            );
                        }
                        Value::Array(mut arr) => {
                            // Already an array, just add "null" if not present
                            if !arr.contains(&Value::String("null".to_string())) {
                                arr.push(Value::String("null".to_string()));
                            }
                            map.insert("type".to_string(), Value::Array(arr));
                        }
                        _ => {}
                    }
                } else if let Some(Value::String(_)) = map.get("$ref") {
                    // Has $ref but no type, use anyOf
                    // (the `unwrap` cannot fail: the branch condition just
                    // confirmed that `$ref` is present)
                    let ref_val = map.remove("$ref").unwrap();
                    let ref_schema = serde_json::json!({"$ref": ref_val});
                    let null_schema = serde_json::json!({"type": "null"});
                    map.insert(
                        "anyOf".to_string(),
                        Value::Array(vec![ref_schema, null_schema]),
                    );
                } else {
                    // No type and no $ref, just set type to null
                    map.insert("type".to_string(), Value::String("null".to_string()));
                }
            }

            // Remove OpenAPI-specific fields
            map.remove("discriminator");
            map.remove("example");
            map.remove("examples");
            map.remove("externalDocs");
            map.remove("xml");
            map.remove("nullable");

            // Convert const to enum with single value (more compatible)
            if let Some(const_val) = map.remove("const") {
                map.insert("enum".to_string(), Value::Array(vec![const_val]));
            }

            // Remove default: null when type is not nullable (causes validation errors)
            // "Nullable" here means `type` is "null" or an array containing "null".
            if let Some(Value::Null) = map.get("default") {
                let type_val = map.get("type");
                let is_nullable = match type_val {
                    Some(Value::Array(arr)) => arr.contains(&Value::String("null".to_string())),
                    Some(Value::String(s)) => s == "null",
                    _ => false,
                };
                if !is_nullable {
                    map.remove("default");
                }
            }

            // Relax overly generic string titles that collide across schemas.
            // Applies only when `type` is "string" or an array containing "string".
            if let Some(Value::String(title)) = map.get("title") {
                let is_string = match map.get("type") {
                    Some(Value::String(t)) => t == "string",
                    Some(Value::Array(arr)) => arr
                        .iter()
                        .any(|v| matches!(v, Value::String(s) if s == "string")),
                    _ => false,
                };
                if is_string {
                    if title == "Id" {
                        map.remove("pattern");
                    } else if title == "Name" {
                        map.remove("enum");
                        map.remove("minLength");
                        map.remove("maxLength");
                    }
                }
            }

            // Recurse
            // (into every remaining value of this object, whatever its key)
            for (_, v) in map.iter_mut() {
                convert_openapi_to_json_schema(v);
            }
        }
        Value::Array(arr) => {
            for item in arr.iter_mut() {
                convert_openapi_to_json_schema(item);
            }
        }
        // Strings, numbers, booleans and null need no conversion.
        _ => {}
    }
}
538
539/// Collect all $ref references from a JSON value.
540fn collect_refs(value: &Value, refs: &mut HashSet<String>) {
541    match value {
542        Value::Object(map) => {
543            if let Some(Value::String(ref_path)) = map.get("$ref") {
544                // Support both OpenAPI format and JSON Schema format
545                if let Some(name) = ref_path
546                    .strip_prefix("#/definitions/")
547                    .or_else(|| ref_path.strip_prefix("#/components/schemas/"))
548                {
549                    refs.insert(name.to_string());
550                }
551            }
552            for v in map.values() {
553                collect_refs(v, refs);
554            }
555        }
556        Value::Array(arr) => {
557            for item in arr {
558                collect_refs(item, refs);
559            }
560        }
561        _ => {}
562    }
563}
564
565/// Filter schemas to only include root types and their transitive dependencies.
566fn filter_schemas(schemas: Value, root_types: &[&str]) -> Value {
567    let schemas_map = match &schemas {
568        Value::Object(map) => map,
569        _ => return schemas,
570    };
571
572    let mut needed: HashSet<String> = root_types.iter().map(|s| s.to_string()).collect();
573    let mut to_process: Vec<String> = root_types.iter().map(|s| s.to_string()).collect();
574
575    while let Some(type_name) = to_process.pop() {
576        if let Some(schema) = schemas_map.get(&type_name) {
577            let mut refs = HashSet::new();
578            collect_refs(schema, &mut refs);
579            for r in refs {
580                if needed.insert(r.clone()) {
581                    to_process.push(r);
582                }
583            }
584        }
585    }
586
587    let filtered: serde_json::Map<String, Value> = schemas_map
588        .iter()
589        .filter(|(k, _)| needed.contains(*k))
590        .map(|(k, v)| (k.clone(), v.clone()))
591        .collect();
592
593    Value::Object(filtered)
594}
595
596/// Extract inline `type` enum fields into separate named definitions.
597///
598/// This prevents typify from merging all `type` enum fields into a single enum,
599/// which would cause serialization/deserialization issues.
600fn extract_inline_type_enums(schemas: &mut Value) {
601    let mut new_definitions: serde_json::Map<String, Value> = serde_json::Map::new();
602
603    if let Value::Object(schemas_map) = schemas {
604        // Collect modifications first
605        let mut modifications: Vec<(String, String)> = Vec::new();
606
607        for (type_name, schema) in schemas_map.iter() {
608            if let Value::Object(obj) = schema
609                && let Some(Value::Object(props)) = obj.get("properties")
610                && let Some(Value::Object(type_prop)) = props.get("type")
611            {
612                // Check if it's a single-value enum
613                if let Some(Value::Array(enum_vals)) = type_prop.get("enum")
614                    && enum_vals.len() == 1
615                {
616                    // Create a unique type name
617                    let unique_type_name = format!("{}Type", type_name);
618                    modifications.push((type_name.clone(), unique_type_name.clone()));
619
620                    // Create the new definition
621                    let mut new_def = type_prop.clone();
622                    new_def.insert("title".to_string(), Value::String(unique_type_name.clone()));
623                    new_definitions.insert(unique_type_name, Value::Object(new_def));
624                }
625            }
626        }
627
628        // Apply modifications
629        for (type_name, unique_type_name) in modifications {
630            if let Some(Value::Object(obj)) = schemas_map.get_mut(&type_name)
631                && let Some(Value::Object(props)) = obj.get_mut("properties")
632                && props.contains_key("type")
633            {
634                // Replace inline enum with $ref
635                props.insert(
636                    "type".to_string(),
637                    serde_json::json!({
638                        "$ref": format!("#/definitions/{}", unique_type_name)
639                    }),
640                );
641            }
642        }
643
644        // Add new definitions
645        for (name, def) in new_definitions {
646            schemas_map.insert(name, def);
647        }
648    }
649}