clawspec_core/split/
strategies.rs

1//! Built-in splitting strategies for OpenAPI specifications.
2
3use std::collections::{BTreeMap, BTreeSet};
4use std::path::PathBuf;
5
6use utoipa::openapi::path::{Operation, PathItem};
7use utoipa::openapi::{Components, OpenApi, Ref, RefOr};
8
9use super::{Fragment, OpenApiSplitter, SplitResult};
10
11/// Helper to iterate over all operations in a PathItem.
12fn iter_operations(path_item: &PathItem) -> impl Iterator<Item = &Operation> {
13    [
14        path_item.get.as_ref(),
15        path_item.put.as_ref(),
16        path_item.post.as_ref(),
17        path_item.delete.as_ref(),
18        path_item.options.as_ref(),
19        path_item.head.as_ref(),
20        path_item.patch.as_ref(),
21        path_item.trace.as_ref(),
22    ]
23    .into_iter()
24    .flatten()
25}
26
/// Splits schemas based on which tags use them.
///
/// This splitter looks at the schemas referenced by the request bodies, responses
/// and parameters of tagged operations and organizes them into separate files:
///
/// - Schemas used by only one tag go into a file named `<tag>.yaml`
/// - Schemas used by multiple tags go into a common file
/// - Schemas that no tagged operation references stay in the main specification
///
/// In the main specification every extracted schema is replaced by a `$ref`
/// pointing at the file that now holds it.
///
/// # Example
///
/// ```rust,ignore
/// use clawspec_core::split::{OpenApiSplitter, SplitSchemasByTag};
///
/// let splitter = SplitSchemasByTag::new("common-types.yaml");
/// let result = splitter.split(spec);
///
/// // Result might contain:
/// // - main openapi.yaml with $refs to external files
/// // - users.yaml with User, CreateUser schemas
/// // - orders.yaml with Order, OrderItem schemas
/// // - common-types.yaml with Error, Pagination schemas used by both
/// ```
#[derive(Debug, Clone)]
pub struct SplitSchemasByTag {
    /// Path of the file that receives schemas used by more than one tag.
    common_file: PathBuf,
    /// Optional directory for schema files; when set, both the tag-specific
    /// files and the common file are placed under it.
    schemas_dir: Option<PathBuf>,
}
56
57impl SplitSchemasByTag {
58    /// Creates a new splitter with the specified common file path.
59    ///
60    /// Tag-specific files will be created in the same directory as the common file.
61    pub fn new(common_file: impl Into<PathBuf>) -> Self {
62        Self {
63            common_file: common_file.into(),
64            schemas_dir: None,
65        }
66    }
67
68    /// Sets the directory for schema files.
69    ///
70    /// Both tag-specific and common files will be placed in this directory.
71    pub fn with_schemas_dir(mut self, dir: impl Into<PathBuf>) -> Self {
72        self.schemas_dir = Some(dir.into());
73        self
74    }
75
76    /// Analyzes which tags reference which schemas.
77    fn analyze_schema_usage(&self, spec: &OpenApi) -> BTreeMap<String, BTreeSet<String>> {
78        let mut schema_to_tags: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
79
80        // Iterate through all paths and operations
81        for (_path, path_item) in spec.paths.paths.iter() {
82            for operation in iter_operations(path_item) {
83                let tags = operation.tags.clone().unwrap_or_default();
84                if tags.is_empty() {
85                    continue;
86                }
87
88                // Collect schema references from request body
89                if let Some(ref request_body) = operation.request_body {
90                    for content in request_body.content.values() {
91                        if let Some(ref schema) = content.schema {
92                            self.collect_schema_refs(schema, &tags, &mut schema_to_tags);
93                        }
94                    }
95                }
96
97                // Collect schema references from responses
98                for response in operation.responses.responses.values() {
99                    if let RefOr::T(resp) = response {
100                        for content in resp.content.values() {
101                            if let Some(ref schema) = content.schema {
102                                self.collect_schema_refs(schema, &tags, &mut schema_to_tags);
103                            }
104                        }
105                    }
106                }
107
108                // Collect schema references from parameters
109                if let Some(ref parameters) = operation.parameters {
110                    for param in parameters {
111                        if let Some(ref schema) = param.schema {
112                            self.collect_schema_refs(schema, &tags, &mut schema_to_tags);
113                        }
114                    }
115                }
116            }
117        }
118
119        schema_to_tags
120    }
121
122    /// Collects schema references from a schema, adding tag associations.
123    fn collect_schema_refs(
124        &self,
125        schema: &RefOr<utoipa::openapi::Schema>,
126        tags: &[String],
127        schema_to_tags: &mut BTreeMap<String, BTreeSet<String>>,
128    ) {
129        match schema {
130            RefOr::Ref(r) => {
131                if let Some(name) = extract_schema_name(&r.ref_location) {
132                    let entry = schema_to_tags.entry(name).or_default();
133                    for tag in tags {
134                        entry.insert(tag.clone());
135                    }
136                }
137            }
138            RefOr::T(_) => {
139                // Inline schema, no reference to extract
140            }
141        }
142    }
143
144    /// Determines the target file for a schema based on its tag usage.
145    fn target_file_for_schema(&self, _schema_name: &str, tags: &BTreeSet<String>) -> PathBuf {
146        let base_dir = self.schemas_dir.clone().unwrap_or_default();
147
148        if tags.len() == 1 {
149            // Schema used by only one tag - put in tag-specific file
150            let tag = tags.iter().next().expect("checked len == 1");
151            base_dir.join(format!("{tag}.yaml"))
152        } else {
153            // Schema used by multiple tags or no tags - put in common file
154            if self.schemas_dir.is_some() {
155                base_dir.join(&self.common_file)
156            } else {
157                self.common_file.clone()
158            }
159        }
160    }
161
162    /// Creates external reference string for a schema in a file.
163    fn create_external_ref(file_path: &std::path::Path, schema_name: &str) -> String {
164        format!(
165            "{}#/components/schemas/{}",
166            file_path.display(),
167            schema_name
168        )
169    }
170}
171
172impl OpenApiSplitter for SplitSchemasByTag {
173    type Fragment = Components;
174
175    fn split(&self, mut spec: OpenApi) -> SplitResult<Self::Fragment> {
176        let schema_to_tags = self.analyze_schema_usage(&spec);
177
178        // Group schemas by their target file
179        let mut file_to_schemas: BTreeMap<PathBuf, BTreeSet<String>> = BTreeMap::new();
180        for (schema_name, tags) in &schema_to_tags {
181            let target = self.target_file_for_schema(schema_name, tags);
182            file_to_schemas
183                .entry(target)
184                .or_default()
185                .insert(schema_name.clone());
186        }
187
188        // If all schemas go to one file or no schemas, no splitting needed
189        if file_to_schemas.len() <= 1 {
190            return SplitResult::new(spec);
191        }
192
193        let mut result = SplitResult::new(spec.clone());
194
195        // Extract schemas and create fragments
196        let original_components = spec.components.take().unwrap_or_default();
197        let mut remaining_schemas = original_components.schemas.clone();
198
199        for (file_path, schema_names) in &file_to_schemas {
200            let mut fragment_components = Components::new();
201
202            for schema_name in schema_names {
203                if let Some(schema) = remaining_schemas.remove(schema_name) {
204                    fragment_components
205                        .schemas
206                        .insert(schema_name.clone(), schema);
207                }
208            }
209
210            if !fragment_components.schemas.is_empty() {
211                result.add_fragment(Fragment::new(file_path.clone(), fragment_components));
212            }
213        }
214
215        // Update the main spec's schema references to point to external files
216        let mut new_components = Components::new();
217
218        // Add external references for extracted schemas
219        for (file_path, schema_names) in &file_to_schemas {
220            for schema_name in schema_names {
221                let external_ref = Self::create_external_ref(file_path, schema_name);
222                new_components
223                    .schemas
224                    .insert(schema_name.clone(), RefOr::Ref(Ref::new(external_ref)));
225            }
226        }
227
228        // Keep any remaining schemas that weren't extracted
229        for (name, schema) in remaining_schemas {
230            new_components.schemas.insert(name, schema);
231        }
232
233        // Preserve security schemes and responses
234        new_components.security_schemes = original_components.security_schemes;
235        new_components.responses = original_components.responses;
236
237        result.main.components = Some(new_components);
238        result
239    }
240}
241
/// Extracts schemas matching a predicate into a separate file.
///
/// This splitter allows fine-grained control over which schemas are extracted
/// by providing a predicate function that determines whether a schema should
/// be moved to the external file.
///
/// # Example
///
/// ```rust,ignore
/// use clawspec_core::split::{OpenApiSplitter, ExtractSchemasByPredicate};
///
/// // Extract all error-related schemas
/// let splitter = ExtractSchemasByPredicate::new(
///     "errors.yaml",
///     |name| name.contains("Error") || name.contains("Exception"),
/// );
/// let result = splitter.split(spec);
/// ```
#[derive(Clone)]
pub struct ExtractSchemasByPredicate<F>
where
    F: Fn(&str) -> bool,
{
    /// Path for the extracted schemas file.
    target_file: PathBuf,
    /// Predicate function that returns true for schemas to extract.
    predicate: F,
}

// `Debug` cannot be derived because `F` is usually a closure, which has no
// `Debug` implementation; the predicate is shown by its type name instead.
impl<F> std::fmt::Debug for ExtractSchemasByPredicate<F>
where
    F: Fn(&str) -> bool,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExtractSchemasByPredicate")
            .field("target_file", &self.target_file)
            .field("predicate", &std::any::type_name::<F>())
            .finish()
    }
}
270
271impl<F> ExtractSchemasByPredicate<F>
272where
273    F: Fn(&str) -> bool,
274{
275    /// Creates a new splitter with the specified target file and predicate.
276    ///
277    /// The predicate receives the schema name and should return `true`
278    /// if the schema should be extracted to the target file.
279    pub fn new(target_file: impl Into<PathBuf>, predicate: F) -> Self {
280        Self {
281            target_file: target_file.into(),
282            predicate,
283        }
284    }
285}
286
287impl<F> OpenApiSplitter for ExtractSchemasByPredicate<F>
288where
289    F: Fn(&str) -> bool,
290{
291    type Fragment = Components;
292
293    fn split(&self, mut spec: OpenApi) -> SplitResult<Self::Fragment> {
294        let Some(mut components) = spec.components.take() else {
295            return SplitResult::new(spec);
296        };
297
298        // Find schemas to extract (collect names first to avoid borrowing issues)
299        let schemas_to_extract: Vec<String> = components
300            .schemas
301            .keys()
302            .filter(|name| (self.predicate)(name))
303            .cloned()
304            .collect();
305
306        // If nothing to extract, return unchanged
307        if schemas_to_extract.is_empty() {
308            spec.components = Some(components);
309            return SplitResult::new(spec);
310        }
311
312        // Extract matching schemas
313        let mut extracted = Components::new();
314        for name in &schemas_to_extract {
315            if let Some(schema) = components.schemas.remove(name) {
316                extracted.schemas.insert(name.clone(), schema);
317            }
318        }
319
320        // Create external references for extracted schemas
321        for name in &schemas_to_extract {
322            let external_ref = format!(
323                "{}#/components/schemas/{}",
324                self.target_file.display(),
325                name
326            );
327            components
328                .schemas
329                .insert(name.clone(), RefOr::Ref(Ref::new(external_ref)));
330        }
331
332        spec.components = Some(components);
333
334        let mut result = SplitResult::new(spec);
335        result.add_fragment(Fragment::new(self.target_file.clone(), extracted));
336        result
337    }
338}
339
/// Extracts the schema name from a $ref string.
///
/// Only local references below `#/components/schemas/` are recognized. When the
/// pointer descends into the schema (for example
/// `#/components/schemas/User/properties/id`), the name of the enclosing
/// component schema (`User`) is returned. References with an empty schema name,
/// or outside `#/components/schemas/`, yield `None`.
///
/// # Example
///
/// ```rust,ignore
/// assert_eq!(extract_schema_name("#/components/schemas/User"), Some("User".to_string()));
/// assert_eq!(extract_schema_name("#/components/responses/Error"), None);
/// ```
fn extract_schema_name(ref_location: &str) -> Option<String> {
    const SCHEMA_PREFIX: &str = "#/components/schemas/";

    let pointer = ref_location.strip_prefix(SCHEMA_PREFIX)?;

    // The first pointer segment is the component schema; anything after it
    // addresses a part inside that schema.
    pointer
        .split('/')
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string)
}
353
#[cfg(test)]
mod tests {
    use super::*;
    use utoipa::openapi::path::OperationBuilder;
    use utoipa::openapi::path::PathItemBuilder;
    use utoipa::openapi::{ContentBuilder, ObjectBuilder, OpenApiBuilder, ResponseBuilder};

    /// Builds a specification with three component schemas and three tagged
    /// GET operations:
    ///
    /// - `/users` (tag `users`) responds with `User`
    /// - `/orders` (tag `orders`) responds with `Order` and `Error`
    /// - `/users/{id}/orders` (tags `users` and `orders`) responds with `Error`
    fn create_test_spec() -> OpenApi {
        let user_schema = ObjectBuilder::new()
            .property(
                "id",
                ObjectBuilder::new().schema_type(utoipa::openapi::Type::Integer),
            )
            .property(
                "name",
                ObjectBuilder::new().schema_type(utoipa::openapi::Type::String),
            )
            .build();

        let error_schema = ObjectBuilder::new()
            .property(
                "code",
                ObjectBuilder::new().schema_type(utoipa::openapi::Type::Integer),
            )
            .property(
                "message",
                ObjectBuilder::new().schema_type(utoipa::openapi::Type::String),
            )
            .build();

        let order_schema = ObjectBuilder::new()
            .property(
                "id",
                ObjectBuilder::new().schema_type(utoipa::openapi::Type::Integer),
            )
            .property(
                "total",
                ObjectBuilder::new().schema_type(utoipa::openapi::Type::Number),
            )
            .build();

        let mut components = Components::new();
        components
            .schemas
            .insert("User".to_string(), RefOr::T(user_schema.into()));
        components
            .schemas
            .insert("Error".to_string(), RefOr::T(error_schema.into()));
        components
            .schemas
            .insert("Order".to_string(), RefOr::T(order_schema.into()));

        // JSON response whose body is a `$ref` to the named component schema.
        let json_ref_response = |schema_name: &str| {
            ResponseBuilder::new()
                .content(
                    "application/json",
                    ContentBuilder::new()
                        .schema(Some(RefOr::Ref(Ref::new(format!(
                            "#/components/schemas/{schema_name}"
                        )))))
                        .build(),
                )
                .build()
        };

        // Create operations with tags
        let get_users = OperationBuilder::new()
            .tags(Some(vec!["users".to_string()]))
            .response("200", json_ref_response("User"))
            .build();

        let get_orders = OperationBuilder::new()
            .tags(Some(vec!["orders".to_string()]))
            .response("200", json_ref_response("Order"))
            .response("400", json_ref_response("Error"))
            .build();

        let get_user_orders = OperationBuilder::new()
            .tags(Some(vec!["users".to_string(), "orders".to_string()]))
            .response("400", json_ref_response("Error"))
            .build();

        let mut paths = utoipa::openapi::Paths::new();
        paths.paths.insert(
            "/users".to_string(),
            PathItemBuilder::new()
                .operation(utoipa::openapi::HttpMethod::Get, get_users)
                .build(),
        );
        paths.paths.insert(
            "/orders".to_string(),
            PathItemBuilder::new()
                .operation(utoipa::openapi::HttpMethod::Get, get_orders)
                .build(),
        );
        paths.paths.insert(
            "/users/{id}/orders".to_string(),
            PathItemBuilder::new()
                .operation(utoipa::openapi::HttpMethod::Get, get_user_orders)
                .build(),
        );

        OpenApiBuilder::new()
            .paths(paths)
            .components(Some(components))
            .build()
    }

    #[test]
    fn should_extract_schema_name() {
        assert_eq!(
            extract_schema_name("#/components/schemas/User"),
            Some("User".to_string())
        );
        assert_eq!(
            extract_schema_name("#/components/schemas/MyError"),
            Some("MyError".to_string())
        );
        assert_eq!(extract_schema_name("#/components/responses/Error"), None);
        assert_eq!(extract_schema_name("User"), None);
        assert_eq!(extract_schema_name(""), None);
    }

    #[test]
    fn should_split_by_predicate() {
        let spec = create_test_spec();

        let splitter = ExtractSchemasByPredicate::new("errors.yaml", |name| name.contains("Error"));
        let result = splitter.split(spec);

        assert_eq!(result.fragment_count(), 1);
        let fragment = &result.fragments[0];
        assert_eq!(fragment.path, PathBuf::from("errors.yaml"));
        assert!(fragment.content.schemas.contains_key("Error"));
        assert!(!fragment.content.schemas.contains_key("User"));
        assert!(!fragment.content.schemas.contains_key("Order"));

        // Main spec should have external reference for Error
        let main_components = result
            .main
            .components
            .as_ref()
            .expect("should have components");
        match main_components.schemas.get("Error") {
            Some(RefOr::Ref(r)) => {
                assert!(r.ref_location.contains("errors.yaml"));
            }
            _ => panic!("Expected external reference for Error"),
        }

        // Schemas that do not match the predicate stay inline in the main spec
        assert!(matches!(
            main_components.schemas.get("User"),
            Some(RefOr::T(_))
        ));
        assert!(matches!(
            main_components.schemas.get("Order"),
            Some(RefOr::T(_))
        ));
    }

    #[test]
    fn should_not_split_when_predicate_matches_nothing() {
        let spec = create_test_spec();

        let splitter =
            ExtractSchemasByPredicate::new("nothing.yaml", |name| name.contains("NonExistent"));
        let result = splitter.split(spec);

        assert!(result.is_unsplit());
    }

    #[test]
    fn should_analyze_schema_usage() {
        let spec = create_test_spec();
        let splitter = SplitSchemasByTag::new("common.yaml");

        let usage = splitter.analyze_schema_usage(&spec);

        // User is used by "users" tag
        assert!(
            usage
                .get("User")
                .map(|t| t.contains("users"))
                .unwrap_or(false)
        );

        // Order is used by "orders" tag
        assert!(
            usage
                .get("Order")
                .map(|t| t.contains("orders"))
                .unwrap_or(false)
        );

        // Error is used by both "users" and "orders" tags
        let error_tags = usage.get("Error").expect("Error should be tracked");
        assert!(error_tags.contains("orders"));
        assert!(error_tags.contains("users"));
        assert_eq!(error_tags.len(), 2);
    }

    #[test]
    fn should_split_schemas_by_tag() {
        let spec = create_test_spec();

        let splitter = SplitSchemasByTag::new("common.yaml");
        let result = splitter.split(spec);

        assert_eq!(result.fragment_count(), 3);

        // Names of the schemas stored in the fragment written to `file`.
        let schemas_in = |file: &str| -> Vec<String> {
            result
                .fragments
                .iter()
                .find(|fragment| fragment.path == PathBuf::from(file))
                .map(|fragment| fragment.content.schemas.keys().cloned().collect())
                .unwrap_or_default()
        };

        assert_eq!(schemas_in("users.yaml"), vec!["User".to_string()]);
        assert_eq!(schemas_in("orders.yaml"), vec!["Order".to_string()]);
        assert_eq!(schemas_in("common.yaml"), vec!["Error".to_string()]);

        // Main spec should reference each extracted schema in its new file
        let main_components = result
            .main
            .components
            .as_ref()
            .expect("should have components");
        let expected_refs = [
            ("User", "users.yaml#/components/schemas/User"),
            ("Order", "orders.yaml#/components/schemas/Order"),
            ("Error", "common.yaml#/components/schemas/Error"),
        ];
        for (name, expected) in expected_refs {
            match main_components.schemas.get(name) {
                Some(RefOr::Ref(r)) => assert_eq!(r.ref_location, expected),
                _ => panic!("Expected external reference for {name}"),
            }
        }
    }
}
569}