rsgen_avro/
generator.rs

1use std::collections::{HashMap, VecDeque};
2use std::fs;
3use std::io::prelude::*;
4
5use apache_avro::schema::{ArraySchema, DecimalSchema, MapSchema, Name, RecordField, RecordSchema};
6
7use crate::Schema;
8use crate::error::{Error, Result};
9use crate::templates::*;
10
11/// An input source for generating Rust types.
12pub enum Source<'a> {
13    /// An Avro schema enum from the `apache-avro` crate.
14    Schema(&'a Schema),
15    /// A slice of Avro schema enums from the `apache-avro` crate.
16    Schemas(&'a [Schema]),
17    /// An Avro schema string in json format.
18    SchemaStr(&'a str),
19    /// Pattern for selecting files containing Avro schemas in json format.
20    GlobPattern(&'a str),
21}
22
/// The main component for generating Rust types from a [`Source`](Source).
///
/// It is stateless and can be reused many times.
#[derive(Debug)]
pub struct Generator {
    /// Template engine carrying the generation options; populated by `GeneratorBuilder::build`.
    templater: Templater,
}
30
31impl Generator {
32    /// Creates a new [`Generator`](Generator) with default configuration.
33    pub fn new() -> Result<Generator> {
34        GeneratorBuilder::new().build()
35    }
36
37    /// Returns a fluid builder for custom [`Generator`](Generator) instantiation.
38    pub fn builder() -> GeneratorBuilder {
39        GeneratorBuilder::new()
40    }
41
42    /// Generates Rust code from an Avro schema [`Source`](Source).
43    /// Writes all generated types to the output.
44    pub fn generate(&self, source: &Source, output: &mut impl Write) -> Result<()> {
45        match source {
46            Source::Schema(schema) => {
47                let mut deps = deps_stack(schema, vec![]);
48                self.gen_in_order(&mut deps, output)?;
49            }
50
51            Source::Schemas(schemas) => {
52                let mut deps = schemas
53                    .iter()
54                    .fold(vec![], |deps, schema| deps_stack(schema, deps));
55
56                self.gen_in_order(&mut deps, output)?;
57            }
58
59            Source::SchemaStr(raw_schema) => {
60                let schema = Schema::parse_str(raw_schema)?;
61                let mut deps = deps_stack(&schema, vec![]);
62                self.gen_in_order(&mut deps, output)?;
63            }
64
65            Source::GlobPattern(pattern) => {
66                let mut raw_schemas = vec![];
67                let mut paths = glob::glob(pattern)?.peekable();
68                if paths.peek().is_none() {
69                    return Err(Error::GlobPattern(glob::PatternError {
70                        pos: 0,
71                        msg: "No files with the given glob pattern were found",
72                    }));
73                }
74                for path in paths {
75                    let path = path.map_err(|e| e.into_error())?;
76                    if !path.is_dir() {
77                        raw_schemas.push(fs::read_to_string(path)?);
78                    }
79                }
80
81                let schemas = &raw_schemas.iter().map(|s| s.as_str()).collect::<Vec<_>>();
82                let schemas = Schema::parse_list(schemas)?;
83                self.generate(&Source::Schemas(&schemas), output)?;
84            }
85        }
86
87        Ok(())
88    }
89
90    /// Given an Avro `schema`:
91    /// * Find its ordered, nested dependencies with `deps_stack(schema)`
92    /// * Pops sub-schemas and generate appropriate Rust types
93    /// * Keeps tracks of nested schema->name with `GenState` mapping
94    /// * Appends generated Rust types to the output
95    fn gen_in_order(&self, deps: &mut Vec<Schema>, output: &mut impl Write) -> Result<()> {
96        let mut gs = GenState::new(deps)?.with_chrono_dates(self.templater.use_chrono_dates);
97
98        if !self.templater.field_overrides.is_empty() {
99            // This rechecks no_eq for all schemas, so only do it if there are actually overrides.
100            gs = gs.with_field_overrides(deps, &self.templater.field_overrides)?;
101        }
102
103        while let Some(s) = deps.pop() {
104            match s {
105                // Simply generate code
106                Schema::Fixed { .. } => {
107                    let code = &self.templater.str_fixed(&s)?;
108                    output.write_all(code.as_bytes())?
109                }
110                Schema::Enum { .. } => {
111                    let code = &self.templater.str_enum(&s)?;
112                    output.write_all(code.as_bytes())?
113                }
114
115                // Generate code with potentially nested types
116                Schema::Record { .. } => {
117                    let code = &self.templater.str_record(&s, &gs)?;
118                    output.write_all(code.as_bytes())?
119                }
120
121                // Register inner type for it to be used as a nested type later
122                Schema::Array(ArraySchema {
123                    items: ref inner, ..
124                }) => {
125                    let type_str = array_type(inner, &gs)?;
126                    gs.put_type(&s, type_str)
127                }
128                Schema::Map(MapSchema {
129                    types: ref inner, ..
130                }) => {
131                    let type_str = map_type(inner, &gs)?;
132                    gs.put_type(&s, type_str)
133                }
134
135                Schema::Union(ref union) => {
136                    // Generate custom enum with potentially nested types
137                    if (union.is_nullable() && union.variants().len() > 2)
138                        || (!union.is_nullable() && !union.variants().is_empty())
139                    {
140                        let code = &self.templater.str_union_enum(&s, &gs)?;
141                        output.write_all(code.as_bytes())?
142                    }
143
144                    // Register inner union for it to be used as a nested type later
145                    let type_str = union_type(union, &gs, true)?;
146                    gs.put_type(&s, type_str)
147                }
148
149                _ => return Err(Error::Schema(format!("Not a valid root schema: {s:?}"))),
150            }
151        }
152
153        Ok(())
154    }
155}
156
157/// Utility function to find the ordered, nested dependencies of an Avro `schema`.
158/// Explores nested `schema`s in a breadth-first fashion, pushing them on a stack at the
159/// same time in order to have them ordered.  It is similar to traversing the `schema`
160/// tree in a post-order fashion.
161fn deps_stack(schema: &Schema, mut deps: Vec<Schema>) -> Vec<Schema> {
162    fn push_unique(deps: &mut Vec<Schema>, s: Schema) {
163        // Check if the schema is already in the stack.
164        // For named types (Record, Enum, Fixed), we check if the name is the same.
165        // This is important because sometimes `apache-avro` produces unequal schema objects
166        // for the same named type (e.g. when resolving references or loading from different files),
167        // which would result in duplicate code generation if we only checked for object equality.
168        // For other types, we fallback to object equality.
169        let existing = deps.iter().position(|d| match (d, &s) {
170            (Schema::Record(r1), Schema::Record(r2)) => r1.name == r2.name,
171            (Schema::Enum(e1), Schema::Enum(e2)) => e1.name == e2.name,
172            (Schema::Fixed(f1), Schema::Fixed(f2)) => f1.name == f2.name,
173            _ => d == &s,
174        });
175
176        if let Some(i) = existing {
177            deps.remove(i);
178        }
179        deps.push(s);
180    }
181
182    let mut q = VecDeque::new();
183
184    q.push_back(schema);
185    while !q.is_empty() {
186        let s = q.pop_front().unwrap();
187
188        match s {
189            // No nested schemas, add them to the result stack
190            Schema::Enum { .. } => push_unique(&mut deps, s.clone()),
191            Schema::Fixed { .. } => push_unique(&mut deps, s.clone()),
192            Schema::Decimal(DecimalSchema { inner, .. })
193                if matches!(inner.as_ref(), Schema::Fixed { .. }) =>
194            {
195                push_unique(&mut deps, s.clone())
196            }
197
198            // Explore the record fields for potentially nested schemas
199            Schema::Record(RecordSchema { fields, .. }) => {
200                push_unique(&mut deps, s.clone());
201
202                let by_pos = fields
203                    .iter()
204                    .map(|f| (f.position, f))
205                    .collect::<HashMap<_, _>>();
206                let mut i = 0;
207                while let Some(RecordField { schema: sr, .. }) = by_pos.get(&i) {
208                    match sr {
209                        // No nested schemas, add them to the result stack
210                        Schema::Fixed { .. } => push_unique(&mut deps, sr.clone()),
211                        Schema::Enum { .. } => push_unique(&mut deps, sr.clone()),
212
213                        // Push to the exploration queue for further checks
214                        Schema::Record { .. } => q.push_back(sr),
215
216                        // Push to the exploration queue, depending on the inner schema format
217                        Schema::Map(MapSchema { types: sc, .. })
218                        | Schema::Array(ArraySchema { items: sc, .. }) => match sc.as_ref() {
219                            Schema::Fixed { .. }
220                            | Schema::Enum { .. }
221                            | Schema::Record { .. }
222                            | Schema::Map(..)
223                            | Schema::Array(..)
224                            | Schema::Union(..) => {
225                                q.push_back(sc);
226                                push_unique(&mut deps, s.clone());
227                            }
228                            _ => (),
229                        },
230                        Schema::Union(union) => {
231                            if (union.is_nullable() && union.variants().len() > 2)
232                                || (!union.is_nullable() && !union.variants().is_empty())
233                            {
234                                push_unique(&mut deps, sr.clone());
235                            }
236
237                            union.variants().iter().for_each(|sc| match sc {
238                                Schema::Fixed { .. }
239                                | Schema::Enum { .. }
240                                | Schema::Record { .. }
241                                | Schema::Map(..)
242                                | Schema::Array(..)
243                                | Schema::Union(..) => {
244                                    q.push_back(sc);
245                                    push_unique(&mut deps, sc.clone());
246                                }
247
248                                _ => (),
249                            });
250                        }
251                        _ => (),
252                    };
253                    i += 1;
254                }
255            }
256
257            // Depending on the inner schema type ...
258            Schema::Map(MapSchema { types: sc, .. })
259            | Schema::Array(ArraySchema { items: sc, .. }) => match sc.as_ref() {
260                // ... Needs further checks, push to the exploration queue
261                Schema::Fixed { .. }
262                | Schema::Enum { .. }
263                | Schema::Record { .. }
264                | Schema::Map(..)
265                | Schema::Array(..)
266                | Schema::Union(..) => {
267                    q.push_back(sc.as_ref());
268                    push_unique(&mut deps, s.clone());
269                }
270                // ... Not nested, can be pushed to the result stack
271                _ => push_unique(&mut deps, s.clone()),
272            },
273
274            Schema::Union(union) => {
275                if (union.is_nullable() && union.variants().len() > 2)
276                    || (!union.is_nullable() && union.variants().len() > 1)
277                {
278                    push_unique(&mut deps, s.clone());
279                }
280
281                union.variants().iter().for_each(|sc| match sc {
282                    // ... Needs further checks, push to the exploration queue
283                    Schema::Fixed { .. }
284                    | Schema::Enum { .. }
285                    | Schema::Record { .. }
286                    | Schema::Map(..)
287                    | Schema::Array(..)
288                    | Schema::Union(..) => {
289                        q.push_back(sc);
290                        push_unique(&mut deps, s.clone());
291                    }
292                    // ... Not nested, can be pushed to the result stack
293                    _ => push_unique(&mut deps, s.clone()),
294                });
295            }
296
297            // Ignore all other schema formats
298            _ => (),
299        }
300    }
301
302    deps
303}
304
305/// A builder class to customize `Generator`.
306pub struct GeneratorBuilder {
307    precision: usize,
308    nullable: bool,
309    use_avro_rs_unions: bool,
310    use_chrono_dates: bool,
311    derive_builders: bool,
312    impl_schemas: ImplementAvroSchema,
313    extra_derives: Vec<String>,
314    field_overrides: HashMap<Name, Vec<FieldOverride>>,
315}
316
317impl Default for GeneratorBuilder {
318    fn default() -> Self {
319        Self {
320            precision: 3,
321            nullable: false,
322            use_avro_rs_unions: false,
323            use_chrono_dates: false,
324            derive_builders: false,
325            impl_schemas: ImplementAvroSchema::None,
326            extra_derives: vec![],
327            field_overrides: HashMap::new(),
328        }
329    }
330}
331
/// How to implement [`AvroSchema`][avsc].
///
/// The default is [`ImplementAvroSchema::None`].
///
/// [avsc]: apache_avro::schema::AvroSchema
#[derive(PartialEq, Eq, Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "build-cli", derive(clap::ValueEnum))]
pub enum ImplementAvroSchema {
    /// Use the [`AvroSchema`][derive] derive.
    ///
    /// This might result in a slightly different schema, as names can have different
    /// capitalisation.
    ///
    /// [derive]: derive@apache_avro::AvroSchema
    Derive,

    /// Copy the schema used at build time.
    ///
    /// This will use the [canonical form](Schema::canonical_form) to create an exact
    /// (canonical) match. Implementations generated by this functionality won't use the
    /// [`AvroSchemaComponent`][avsc-compo] implementation of subtypes.
    ///
    /// [avsc-compo]: apache_avro::schema::derive::AvroSchemaComponent
    CopyBuildSchema,

    /// Do not implement or derive [`AvroSchema`][avsc].
    ///
    /// [avsc]: apache_avro::schema::AvroSchema
    #[default]
    None,
}
361
362impl std::fmt::Display for ImplementAvroSchema {
363    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
364        write!(f, "{self:?}")
365    }
366}
367
368/// Override (part of) the generated code for a [`Schema`] field.
369///
370/// Currently only possible for Record schemas.
371///
372/// When changing the type of the field, `implements_eq` must also be changed.
373/// `serde_with` and `default` might also need to be changed. If this is not done
374/// it will result in a compiler error.
375#[derive(Debug, Clone)]
376
377pub struct FieldOverride {
378    /// Name of the schema the field is in.
379    pub schema: Name,
380    /// Name of the field as in the schema.
381    pub field: String,
382    /// Change the documentation of the field.
383    pub docstring: Option<String>,
384    /// Change the type of the field.
385    ///
386    /// This type *must* implement [`Debug`], [`PartialEq`], and [`Clone`].
387    /// If extra derives are configured (including `Builder` and [`AvroSchema`](apache_avro::AvroSchema))
388    /// then the type *must* also implement these.
389    pub type_name: Option<String>,
390    /// Does the type implement [`Eq`].
391    ///
392    /// This *must* be set if the type is changed.
393    pub implements_eq: Option<bool>,
394    /// Module name to use for `#[serde(with = ...)]`.
395    ///
396    /// This *must* be set if the type was changed and the type does not implement `Serialize` or `Deserialize`.
397    pub serde_with: Option<String>,
398    /// Default value for this field, can be a function call that generates the default value.
399    ///
400    /// This *must* be set if the field is nullable and the type was changed and the outer type is not an [`Option`].
401    pub default: Option<String>,
402}
403
404impl GeneratorBuilder {
405    /// Creates a new [`GeneratorBuilder`](GeneratorBuilder).
406    pub fn new() -> GeneratorBuilder {
407        GeneratorBuilder::default()
408    }
409
410    /// Sets the precision for default values of f32/f64 fields.
411    pub fn precision(mut self, precision: usize) -> GeneratorBuilder {
412        self.precision = precision;
413        self
414    }
415
416    /// Puts default value when deserializing `null` field.
417    ///
418    /// Doesn't apply to union fields ["null", "Foo"], which are `Option<Foo>`.
419    pub fn nullable(mut self, nullable: bool) -> GeneratorBuilder {
420        self.nullable = nullable;
421        self
422    }
423
424    /// Adds support for deserializing union types from the `apache-avro` crate.
425    ///
426    /// Only necessary for unions of 3 or more types or 2-type unions without "null".
427    /// Note that only int, long, float, double, boolean and bytes values are currently supported.
428    pub fn use_avro_rs_unions(mut self, use_avro_rs_unions: bool) -> GeneratorBuilder {
429        self.use_avro_rs_unions = use_avro_rs_unions;
430        self
431    }
432
433    /// Use chrono::NaiveDateTime for date/timestamps logical types
434    pub fn use_chrono_dates(mut self, use_chrono_dates: bool) -> GeneratorBuilder {
435        self.use_chrono_dates = use_chrono_dates;
436        self
437    }
438
439    /// Adds support to derive builders using the `rust-derive-builder` crate.
440    ///
441    /// Applies to record structs.
442    pub fn derive_builders(mut self, derive_builders: bool) -> GeneratorBuilder {
443        self.derive_builders = derive_builders;
444        self
445    }
446
447    /// Add an implementation of [`AvroSchema`][avsc].
448    ///
449    /// This implementation can either use a derive or copy the schema used to generate the type.
450    /// See [`ImplementAvroSchema`] for more information.
451    ///
452    /// Applies to record structs.
453    ///
454    /// [avsc]: apache_avro::schema::AvroSchema
455    pub fn implement_avro_schema(mut self, impl_schemas: ImplementAvroSchema) -> GeneratorBuilder {
456        self.impl_schemas = impl_schemas;
457        self
458    }
459
460    /// Adds support to derive custom macros.
461    ///
462    /// Applies to record structs.
463    pub fn extra_derives(mut self, extra_derives: Vec<String>) -> GeneratorBuilder {
464        self.extra_derives = extra_derives;
465        self
466    }
467
468    /// Override (part of) the code generated for a field.
469    ///
470    /// Applies to record structs.
471    pub fn override_fields(mut self, overrides: Vec<FieldOverride>) -> GeneratorBuilder {
472        for over in overrides {
473            self.field_overrides
474                .entry(over.schema.clone())
475                .or_default()
476                .push(over);
477        }
478        self
479    }
480
481    /// Override (part of) the code generated for a field.
482    ///
483    /// Applies to record structs.
484    pub fn override_field(mut self, over: FieldOverride) -> GeneratorBuilder {
485        self.field_overrides
486            .entry(over.schema.clone())
487            .or_default()
488            .push(over);
489        self
490    }
491
492    /// Create a [`Generator`](Generator) with the builder parameters.
493    pub fn build(self) -> Result<Generator> {
494        let mut templater = Templater::new()?;
495        templater.precision = self.precision;
496        templater.nullable = self.nullable;
497        templater.use_avro_rs_unions = self.use_avro_rs_unions;
498        templater.use_chrono_dates = self.use_chrono_dates;
499        templater.derive_builders = self.derive_builders;
500        templater.derive_schemas = self.impl_schemas == ImplementAvroSchema::Derive;
501        templater.impl_schemas = self.impl_schemas == ImplementAvroSchema::CopyBuildSchema;
502        templater.extra_derives = self.extra_derives;
503        templater.field_overrides = self.field_overrides;
504        Ok(Generator { templater })
505    }
506}
507
#[cfg(test)]
mod tests {
    use apache_avro::schema::{EnumSchema, Name};
    use pretty_assertions::assert_eq;

    use super::*;

    /// The dependency stack of a nested record pops the innermost type first and the
    /// root record last.
    #[test]
    fn deps() {
        let raw_schema = r#"
{
  "type": "record",
  "name": "User",
  "fields": [
    {"name": "name", "type": "string", "default": "unknown"},
    {"name": "address",
     "type": {
       "type": "record",
       "name": "Address",
       "fields": [
         {"name": "city", "type": "string", "default": "unknown"},
         {"name": "country",
          "type": {"type": "enum", "name": "Country", "symbols": ["FR", "JP"]}
         }
       ]
     }
    }
  ]
}
"#;

        let parsed = Schema::parse_str(raw_schema).unwrap();
        let mut stack = deps_stack(&parsed, vec![]);

        // Deepest dependency first: the enum used by `Address`.
        assert!(matches!(
            stack.pop(),
            Some(Schema::Enum(EnumSchema { name: Name { ref name, .. }, .. })) if name == "Country"
        ));

        // Then the nested record.
        assert!(matches!(
            stack.pop(),
            Some(Schema::Record(RecordSchema { name: Name { ref name, .. }, .. })) if name == "Address"
        ));

        // And finally the root record.
        assert!(matches!(
            stack.pop(),
            Some(Schema::Record(RecordSchema { name: Name { ref name, .. }, .. })) if name == "User"
        ));

        // Nothing else may be left on the stack.
        assert!(stack.pop().is_none());
    }

    /// Schemas spread over several files selected by a glob pattern can reference each
    /// other, and each type is generated once.
    #[test]
    fn cross_deps() -> std::result::Result<(), Box<dyn std::error::Error>> {
        use std::fs::File;
        use std::io::Write;
        use tempfile::tempdir;

        let workdir = tempdir()?;

        // Schema `A` stands on its own.
        let a_source = r#"
{
  "name": "A",
  "type": "record",
  "fields": [ {"name": "field_one", "type": "float"} ]
}
"#;
        let mut a_file = File::create(workdir.path().join("schema_a.avsc"))?;
        a_file.write_all(a_source.as_bytes())?;

        // Schema `B` refers to `A`, which is defined in the other file.
        let b_source = r#"
{
  "name": "B",
  "type": "record",
  "fields": [ {"name": "field_one", "type": "A"} ]
}
"#;
        let mut b_file = File::create(workdir.path().join("schema_b.avsc"))?;
        b_file.write_all(b_source.as_bytes())?;

        let expected = r#"
#[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
pub struct B {
    pub field_one: A,
}

#[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
pub struct A {
    pub field_one: f32,
}
"#;

        let pattern = format!("{}/*.avsc", workdir.path().display());
        let generator = Generator::new()?;
        let mut generated = Vec::new();
        generator.generate(&Source::GlobPattern(pattern.as_str()), &mut generated)?;
        let res = String::from_utf8(generated)?;

        assert_eq!(expected, res);

        // Release the file handles before removing the temporary directory.
        drop(a_file);
        drop(b_file);
        workdir.close()?;
        Ok(())
    }
}