1use std::collections::{BTreeMap, BTreeSet};
4use std::path::PathBuf;
5
6use utoipa::openapi::path::{Operation, PathItem};
7use utoipa::openapi::{Components, OpenApi, Ref, RefOr};
8
9use super::{Fragment, OpenApiSplitter, SplitResult};
10
11fn iter_operations(path_item: &PathItem) -> impl Iterator<Item = &Operation> {
13 [
14 path_item.get.as_ref(),
15 path_item.put.as_ref(),
16 path_item.post.as_ref(),
17 path_item.delete.as_ref(),
18 path_item.options.as_ref(),
19 path_item.head.as_ref(),
20 path_item.patch.as_ref(),
21 path_item.trace.as_ref(),
22 ]
23 .into_iter()
24 .flatten()
25}
26
/// Splitting strategy that groups component schemas by the operation tags
/// that reference them.
///
/// A schema referenced under exactly one tag is assigned to `<tag>.yaml`;
/// a schema referenced under several tags is assigned to the common file.
#[derive(Clone, Debug)]
pub struct SplitSchemasByTag {
    /// File that receives schemas referenced under more than one tag.
    common_file: PathBuf,
    /// Optional directory that every generated file path is placed under.
    schemas_dir: Option<PathBuf>,
}
56
57impl SplitSchemasByTag {
58 pub fn new(common_file: impl Into<PathBuf>) -> Self {
62 Self {
63 common_file: common_file.into(),
64 schemas_dir: None,
65 }
66 }
67
68 pub fn with_schemas_dir(mut self, dir: impl Into<PathBuf>) -> Self {
72 self.schemas_dir = Some(dir.into());
73 self
74 }
75
76 fn analyze_schema_usage(&self, spec: &OpenApi) -> BTreeMap<String, BTreeSet<String>> {
78 let mut schema_to_tags: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
79
80 for (_path, path_item) in spec.paths.paths.iter() {
82 for operation in iter_operations(path_item) {
83 let tags = operation.tags.clone().unwrap_or_default();
84 if tags.is_empty() {
85 continue;
86 }
87
88 if let Some(ref request_body) = operation.request_body {
90 for content in request_body.content.values() {
91 if let Some(ref schema) = content.schema {
92 self.collect_schema_refs(schema, &tags, &mut schema_to_tags);
93 }
94 }
95 }
96
97 for response in operation.responses.responses.values() {
99 if let RefOr::T(resp) = response {
100 for content in resp.content.values() {
101 if let Some(ref schema) = content.schema {
102 self.collect_schema_refs(schema, &tags, &mut schema_to_tags);
103 }
104 }
105 }
106 }
107
108 if let Some(ref parameters) = operation.parameters {
110 for param in parameters {
111 if let Some(ref schema) = param.schema {
112 self.collect_schema_refs(schema, &tags, &mut schema_to_tags);
113 }
114 }
115 }
116 }
117 }
118
119 schema_to_tags
120 }
121
122 fn collect_schema_refs(
124 &self,
125 schema: &RefOr<utoipa::openapi::Schema>,
126 tags: &[String],
127 schema_to_tags: &mut BTreeMap<String, BTreeSet<String>>,
128 ) {
129 match schema {
130 RefOr::Ref(r) => {
131 if let Some(name) = extract_schema_name(&r.ref_location) {
132 let entry = schema_to_tags.entry(name).or_default();
133 for tag in tags {
134 entry.insert(tag.clone());
135 }
136 }
137 }
138 RefOr::T(_) => {
139 }
141 }
142 }
143
144 fn target_file_for_schema(&self, _schema_name: &str, tags: &BTreeSet<String>) -> PathBuf {
146 let base_dir = self.schemas_dir.clone().unwrap_or_default();
147
148 if tags.len() == 1 {
149 let tag = tags.iter().next().expect("checked len == 1");
151 base_dir.join(format!("{tag}.yaml"))
152 } else {
153 if self.schemas_dir.is_some() {
155 base_dir.join(&self.common_file)
156 } else {
157 self.common_file.clone()
158 }
159 }
160 }
161
162 fn create_external_ref(file_path: &std::path::Path, schema_name: &str) -> String {
164 format!(
165 "{}#/components/schemas/{}",
166 file_path.display(),
167 schema_name
168 )
169 }
170}
171
172impl OpenApiSplitter for SplitSchemasByTag {
173 type Fragment = Components;
174
175 fn split(&self, mut spec: OpenApi) -> SplitResult<Self::Fragment> {
176 let schema_to_tags = self.analyze_schema_usage(&spec);
177
178 let mut file_to_schemas: BTreeMap<PathBuf, BTreeSet<String>> = BTreeMap::new();
180 for (schema_name, tags) in &schema_to_tags {
181 let target = self.target_file_for_schema(schema_name, tags);
182 file_to_schemas
183 .entry(target)
184 .or_default()
185 .insert(schema_name.clone());
186 }
187
188 if file_to_schemas.len() <= 1 {
190 return SplitResult::new(spec);
191 }
192
193 let mut result = SplitResult::new(spec.clone());
194
195 let original_components = spec.components.take().unwrap_or_default();
197 let mut remaining_schemas = original_components.schemas.clone();
198
199 for (file_path, schema_names) in &file_to_schemas {
200 let mut fragment_components = Components::new();
201
202 for schema_name in schema_names {
203 if let Some(schema) = remaining_schemas.remove(schema_name) {
204 fragment_components
205 .schemas
206 .insert(schema_name.clone(), schema);
207 }
208 }
209
210 if !fragment_components.schemas.is_empty() {
211 result.add_fragment(Fragment::new(file_path.clone(), fragment_components));
212 }
213 }
214
215 let mut new_components = Components::new();
217
218 for (file_path, schema_names) in &file_to_schemas {
220 for schema_name in schema_names {
221 let external_ref = Self::create_external_ref(file_path, schema_name);
222 new_components
223 .schemas
224 .insert(schema_name.clone(), RefOr::Ref(Ref::new(external_ref)));
225 }
226 }
227
228 for (name, schema) in remaining_schemas {
230 new_components.schemas.insert(name, schema);
231 }
232
233 new_components.security_schemes = original_components.security_schemes;
235 new_components.responses = original_components.responses;
236
237 result.main.components = Some(new_components);
238 result
239 }
240}
241
/// Splitting strategy that moves every component schema whose name satisfies
/// a predicate into a single target file.
///
/// The `Fn(&str) -> bool` requirement on `F` is stated on the impl blocks
/// that call the predicate rather than on the type itself.
#[derive(Clone)]
pub struct ExtractSchemasByPredicate<F> {
    /// File that receives the extracted schemas.
    target_file: PathBuf,
    /// Called with a schema name; `true` selects the schema for extraction.
    predicate: F,
}

/// Manual `Debug` so the type is printable even though closures are not;
/// the predicate is omitted from the output.
impl<F> std::fmt::Debug for ExtractSchemasByPredicate<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExtractSchemasByPredicate")
            .field("target_file", &self.target_file)
            .finish_non_exhaustive()
    }
}
270
271impl<F> ExtractSchemasByPredicate<F>
272where
273 F: Fn(&str) -> bool,
274{
275 pub fn new(target_file: impl Into<PathBuf>, predicate: F) -> Self {
280 Self {
281 target_file: target_file.into(),
282 predicate,
283 }
284 }
285}
286
287impl<F> OpenApiSplitter for ExtractSchemasByPredicate<F>
288where
289 F: Fn(&str) -> bool,
290{
291 type Fragment = Components;
292
293 fn split(&self, mut spec: OpenApi) -> SplitResult<Self::Fragment> {
294 let Some(mut components) = spec.components.take() else {
295 return SplitResult::new(spec);
296 };
297
298 let schemas_to_extract: Vec<String> = components
300 .schemas
301 .keys()
302 .filter(|name| (self.predicate)(name))
303 .cloned()
304 .collect();
305
306 if schemas_to_extract.is_empty() {
308 spec.components = Some(components);
309 return SplitResult::new(spec);
310 }
311
312 let mut extracted = Components::new();
314 for name in &schemas_to_extract {
315 if let Some(schema) = components.schemas.remove(name) {
316 extracted.schemas.insert(name.clone(), schema);
317 }
318 }
319
320 for name in &schemas_to_extract {
322 let external_ref = format!(
323 "{}#/components/schemas/{}",
324 self.target_file.display(),
325 name
326 );
327 components
328 .schemas
329 .insert(name.clone(), RefOr::Ref(Ref::new(external_ref)));
330 }
331
332 spec.components = Some(components);
333
334 let mut result = SplitResult::new(spec);
335 result.add_fragment(Fragment::new(self.target_file.clone(), extracted));
336 result
337 }
338}
339
/// Extracts the schema name from a local component reference.
///
/// Returns `Some(name)` when `ref_location` starts with
/// `#/components/schemas/` and has a non-empty remainder. Returns `None` for
/// any other reference, including one that consists of the prefix alone, so
/// callers never receive an empty schema name.
fn extract_schema_name(ref_location: &str) -> Option<String> {
    const SCHEMA_PREFIX: &str = "#/components/schemas/";
    ref_location
        .strip_prefix(SCHEMA_PREFIX)
        .filter(|name| !name.is_empty())
        .map(str::to_owned)
}
353
#[cfg(test)]
mod tests {
    use super::*;
    use utoipa::openapi::path::OperationBuilder;
    use utoipa::openapi::path::PathItemBuilder;
    use utoipa::openapi::{ContentBuilder, ObjectBuilder, OpenApiBuilder, ResponseBuilder};

    /// Builds an `application/json` response whose body is a `$ref` to the
    /// named component schema.
    fn json_ref_response(schema_name: &str) -> utoipa::openapi::Response {
        let reference = Ref::new(format!("#/components/schemas/{schema_name}"));
        ResponseBuilder::new()
            .content(
                "application/json",
                ContentBuilder::new()
                    .schema(Some(RefOr::Ref(reference)))
                    .build(),
            )
            .build()
    }

    /// Builds an owned tag set from string literals.
    fn tag_set(tags: &[&str]) -> BTreeSet<String> {
        tags.iter().map(|tag| tag.to_string()).collect()
    }

    /// Spec with three schemas: `User` (tag `users`), `Order` (tag `orders`)
    /// and `Error`, which is referenced under both tags.
    fn create_test_spec() -> OpenApi {
        use utoipa::openapi::{HttpMethod, Type};

        let user_schema = ObjectBuilder::new()
            .property("id", ObjectBuilder::new().schema_type(Type::Integer))
            .property("name", ObjectBuilder::new().schema_type(Type::String))
            .build();

        let error_schema = ObjectBuilder::new()
            .property("code", ObjectBuilder::new().schema_type(Type::Integer))
            .property("message", ObjectBuilder::new().schema_type(Type::String))
            .build();

        let order_schema = ObjectBuilder::new()
            .property("id", ObjectBuilder::new().schema_type(Type::Integer))
            .property("total", ObjectBuilder::new().schema_type(Type::Number))
            .build();

        let mut components = Components::new();
        components
            .schemas
            .insert("User".to_string(), RefOr::T(user_schema.into()));
        components
            .schemas
            .insert("Error".to_string(), RefOr::T(error_schema.into()));
        components
            .schemas
            .insert("Order".to_string(), RefOr::T(order_schema.into()));

        let get_users = OperationBuilder::new()
            .tags(Some(vec!["users".to_string()]))
            .response("200", json_ref_response("User"))
            .build();

        let get_orders = OperationBuilder::new()
            .tags(Some(vec!["orders".to_string()]))
            .response("200", json_ref_response("Order"))
            .response("400", json_ref_response("Error"))
            .build();

        let get_user_orders = OperationBuilder::new()
            .tags(Some(vec!["users".to_string(), "orders".to_string()]))
            .response("400", json_ref_response("Error"))
            .build();

        let mut paths = utoipa::openapi::Paths::new();
        for (route, operation) in [
            ("/users", get_users),
            ("/orders", get_orders),
            ("/users/{id}/orders", get_user_orders),
        ] {
            paths.paths.insert(
                route.to_string(),
                PathItemBuilder::new()
                    .operation(HttpMethod::Get, operation)
                    .build(),
            );
        }

        OpenApiBuilder::new()
            .paths(paths)
            .components(Some(components))
            .build()
    }

    #[test]
    fn should_extract_schema_name() {
        assert_eq!(
            extract_schema_name("#/components/schemas/User"),
            Some("User".to_string())
        );
        assert_eq!(
            extract_schema_name("#/components/schemas/MyError"),
            Some("MyError".to_string())
        );
        assert_eq!(extract_schema_name("#/components/responses/Error"), None);
        assert_eq!(extract_schema_name("User"), None);
    }

    #[test]
    fn should_split_by_predicate() {
        let spec = create_test_spec();

        let splitter = ExtractSchemasByPredicate::new("errors.yaml", |name| name.contains("Error"));
        let result = splitter.split(spec);

        assert_eq!(result.fragment_count(), 1);
        let fragment = &result.fragments[0];
        assert_eq!(fragment.path, PathBuf::from("errors.yaml"));
        assert!(fragment.content.schemas.contains_key("Error"));
        assert!(!fragment.content.schemas.contains_key("User"));
        assert!(!fragment.content.schemas.contains_key("Order"));

        let main_components = result
            .main
            .components
            .as_ref()
            .expect("should have components");
        match main_components.schemas.get("Error") {
            Some(RefOr::Ref(r)) => {
                assert!(r.ref_location.contains("errors.yaml"));
            }
            _ => panic!("Expected external reference for Error"),
        }
    }

    #[test]
    fn should_not_split_when_predicate_matches_nothing() {
        let spec = create_test_spec();

        let splitter =
            ExtractSchemasByPredicate::new("nothing.yaml", |name| name.contains("NonExistent"));
        let result = splitter.split(spec);

        assert!(result.is_unsplit());
    }

    #[test]
    fn should_analyze_schema_usage() {
        let spec = create_test_spec();
        let splitter = SplitSchemasByTag::new("common.yaml");

        let usage = splitter.analyze_schema_usage(&spec);

        assert_eq!(usage.len(), 3);
        assert_eq!(usage.get("User"), Some(&tag_set(&["users"])));
        assert_eq!(usage.get("Order"), Some(&tag_set(&["orders"])));
        // `Error` is referenced by an `orders` operation and by an operation
        // tagged with both `users` and `orders`, so it must carry both tags.
        assert_eq!(usage.get("Error"), Some(&tag_set(&["orders", "users"])));
    }

    #[test]
    fn should_choose_target_file_from_tags() {
        let flat = SplitSchemasByTag::new("common.yaml");
        assert_eq!(
            flat.target_file_for_schema("User", &tag_set(&["users"])),
            PathBuf::from("users.yaml")
        );
        assert_eq!(
            flat.target_file_for_schema("Error", &tag_set(&["users", "orders"])),
            PathBuf::from("common.yaml")
        );

        let nested = SplitSchemasByTag::new("common.yaml").with_schemas_dir("schemas");
        assert_eq!(
            nested.target_file_for_schema("User", &tag_set(&["users"])),
            PathBuf::from("schemas").join("users.yaml")
        );
        assert_eq!(
            nested.target_file_for_schema("Error", &tag_set(&["users", "orders"])),
            PathBuf::from("schemas").join("common.yaml")
        );
    }

    #[test]
    fn should_split_schemas_by_tag() {
        let spec = create_test_spec();

        let splitter = SplitSchemasByTag::new("common.yaml");
        let result = splitter.split(spec);

        assert_eq!(result.fragment_count(), 3);

        let main_components = result
            .main
            .components
            .as_ref()
            .expect("should have components");

        for (file, schema) in [
            ("users.yaml", "User"),
            ("orders.yaml", "Order"),
            ("common.yaml", "Error"),
        ] {
            let fragment = result
                .fragments
                .iter()
                .find(|fragment| fragment.path == PathBuf::from(file))
                .unwrap_or_else(|| panic!("missing fragment for {file}"));
            assert_eq!(fragment.content.schemas.len(), 1);
            assert!(fragment.content.schemas.contains_key(schema));

            match main_components.schemas.get(schema) {
                Some(RefOr::Ref(r)) => {
                    assert_eq!(
                        r.ref_location,
                        format!("{file}#/components/schemas/{schema}")
                    );
                }
                _ => panic!("Expected external reference for {schema}"),
            }
        }
    }
}