1use std::collections::HashSet;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use crate::{OpenAPI, Parameter, RequestBody, Response, Schema};
5
/// A parsed `$ref` that points at a schema, or at one property of a schema.
///
/// The [`Display`](std::fmt::Display) impl renders the canonical
/// `#/components/schemas/...` form of the reference.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaReference {
    /// A whole named schema, e.g. `#/components/schemas/Pet`.
    Schema {
        schema: String,
    },
    /// One property of a named schema, e.g. `#/components/schemas/Pet/properties/name`.
    Property {
        schema: String,
        property: String,
    },
}

impl SchemaReference {
    /// Parses a schema `$ref` string.
    ///
    /// Only the trailing path segments are inspected. A reference ending in
    /// `schemas/{schema}` yields [`SchemaReference::Schema`], and one ending in
    /// `{schema}/properties/{property}` yields [`SchemaReference::Property`].
    ///
    /// # Panics
    ///
    /// Panics with `Unknown reference: ...` when `reference` matches neither shape,
    /// including references too short to contain the expected segments.
    // Kept as an inherent, panicking constructor; a `FromStr` impl would need an error type.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(reference: &str) -> Self {
        // Walk the path from the end: the name, its parent segment, then (for
        // property references) the owning schema.
        let mut segments = reference.rsplit('/');
        let last = segments.next();
        let parent = segments.next();
        let owner = segments.next();
        match (parent, last, owner) {
            (Some("schemas"), Some(schema), _) => Self::Schema {
                schema: schema.to_string(),
            },
            (Some("properties"), Some(property), Some(schema)) => Self::Property {
                schema: schema.to_string(),
                property: property.to_string(),
            },
            _ => panic!("Unknown reference: {}", reference),
        }
    }
}

impl std::fmt::Display for SchemaReference {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SchemaReference::Schema { schema } => write!(f, "#/components/schemas/{}", schema),
            SchemaReference::Property { schema, property } => {
                write!(f, "#/components/schemas/{}/properties/{}", schema, property)
            }
        }
    }
}
49
50
51pub type ReferenceOr<T> = RefOr<T>;
53
54#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
55#[serde(untagged)]
56pub enum RefOr<T> {
57 Reference {
58 #[serde(rename = "$ref")]
59 reference: String,
60 },
61 Item(T),
62}
63
64impl<T> RefOr<T> {
65 pub fn ref_(r: &str) -> Self {
66 RefOr::Reference {
67 reference: r.to_owned(),
68 }
69 }
70 pub fn schema_ref(r: &str) -> Self {
71 RefOr::Reference {
72 reference: format!("#/components/schemas/{}", r),
73 }
74 }
75
76 pub fn boxed(self) -> Box<RefOr<T>> {
77 Box::new(self)
78 }
79
80 pub fn into_item(self) -> Option<T> {
96 match self {
97 RefOr::Reference { .. } => None,
98 RefOr::Item(i) => Some(i),
99 }
100 }
101
102 pub fn as_item(&self) -> Option<&T> {
118 match self {
119 RefOr::Reference { .. } => None,
120 RefOr::Item(i) => Some(i),
121 }
122 }
123
124 pub fn as_ref_str(&self) -> Option<&str> {
125 match self {
126 RefOr::Reference { reference } => Some(reference),
127 RefOr::Item(_) => None,
128 }
129 }
130
131 pub fn as_mut(&mut self) -> Option<&mut T> {
132 match self {
133 RefOr::Reference { .. } => None,
134 RefOr::Item(i) => Some(i),
135 }
136 }
137
138 pub fn to_mut(&mut self) -> &mut T {
139 self.as_mut().expect("Not an item")
140 }
141}
142
143fn resolve_helper<'a>(reference: &str, spec: &'a OpenAPI, seen: &mut HashSet<String>) -> &'a Schema {
144 if seen.contains(reference) {
145 panic!("Circular reference: {}", reference);
146 }
147 seen.insert(reference.to_string());
148 let reference = SchemaReference::from_str(&reference);
149 match &reference {
150 SchemaReference::Schema { ref schema } => {
151 let schema_ref = spec.schemas.get(schema)
152 .expect(&format!("Schema {} not found in OpenAPI spec.", schema));
153 match schema_ref {
156 RefOr::Reference { reference } => {
157 resolve_helper(&reference, spec, seen)
158 }
159 RefOr::Item(s) => s
160 }
161 }
162 SchemaReference::Property { schema: schema_name, property } => {
163 let schema = spec.schemas.get(schema_name)
164 .expect(&format!("Schema {} not found in OpenAPI spec.", schema_name))
165 .as_item()
166 .expect(&format!("The schema {} was used in a reference, but that schema is itself a reference to another schema.", schema_name));
167 let prop_schema = schema
168 .properties()
169 .get(property)
170 .expect(&format!("Schema {} does not have property {}.", schema_name, property));
171 prop_schema.resolve(spec)
172 }
173 }
174}
175
176impl RefOr<Schema> {
177 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> &'a Schema {
178 match self {
179 RefOr::Reference { reference } => {
180 resolve_helper(reference, spec, &mut HashSet::new())
181 }
182 RefOr::Item(schema) => schema,
183 }
184 }
185}
186
187impl<T> From<T> for RefOr<T> {
188 fn from(item: T) -> Self {
189 RefOr::Item(item)
190 }
191}
192
193impl RefOr<Parameter> {
194 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> Result<&'a Parameter> {
195 match self {
196 RefOr::Reference { reference } => {
197 let name = get_parameter_name(&reference)?;
198 spec.parameters.get(name)
199 .ok_or(anyhow!("{} not found in OpenAPI spec.", reference))?
200 .as_item()
201 .ok_or(anyhow!("{} is circular.", reference))
202 }
203 RefOr::Item(parameter) => Ok(parameter),
204 }
205 }
206}
207
208
209impl RefOr<Response> {
210 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> Result<&'a Response> {
211 match self {
212 RefOr::Reference { reference } => {
213 let name = get_response_name(&reference)?;
214 spec.responses.get(name)
215 .ok_or(anyhow!("{} not found in OpenAPI spec.", reference))?
216 .as_item()
217 .ok_or(anyhow!("{} is circular.", reference))
218 }
219 RefOr::Item(response) => Ok(response),
220 }
221 }
222}
223
224impl RefOr<RequestBody> {
225 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> Result<&'a RequestBody> {
226 match self {
227 RefOr::Reference { reference } => {
228 let name = get_request_body_name(&reference)?;
229 spec.request_bodies.get(name)
230 .ok_or(anyhow!("{} not found in OpenAPI spec.", reference))?
231 .as_item()
232 .ok_or(anyhow!("{} is circular.", reference))
233 }
234 RefOr::Item(request_body) => Ok(request_body),
235 }
236 }
237}
238
239impl<T: Default> Default for RefOr<T> {
240 fn default() -> Self {
241 RefOr::Item(T::default())
242 }
243}
244
245fn parse_reference<'a>(reference: &'a str, group: &str) -> Result<&'a str> {
246 let mut parts = reference.rsplitn(2, '/');
247 let name = parts.next();
248 name.filter(|_| matches!(parts.next(), Some(x) if format!("#/components/{group}") == x))
249 .ok_or(anyhow!("Invalid {} reference: {}", group, reference))
250}
251
252
253fn get_response_name(reference: &str) -> Result<&str> {
254 parse_reference(reference, "responses")
255}
256
257
258fn get_request_body_name(reference: &str) -> Result<&str> {
259 parse_reference(reference, "requestBodies")
260}
261
262fn get_parameter_name(reference: &str) -> Result<&str> {
263 parse_reference(reference, "parameters")
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_request_body_name() {
        assert!(matches!(get_request_body_name("#/components/requestBodies/Foo"), Ok("Foo")));
        assert!(get_request_body_name("#/components/schemas/Foo").is_err());
    }

    #[test]
    fn test_get_parameter_and_response_names() {
        assert!(matches!(get_parameter_name("#/components/parameters/Limit"), Ok("Limit")));
        assert!(matches!(get_response_name("#/components/responses/NotFound"), Ok("NotFound")));
        // A reference into a different components group is rejected.
        assert!(get_parameter_name("#/components/responses/NotFound").is_err());
        assert!(get_response_name("#/components/parameters/Limit").is_err());
    }

    #[test]
    fn test_component_names_reject_malformed_paths() {
        // Missing the `#/components/` prefix.
        assert!(get_parameter_name("parameters/Limit").is_err());
        // Extra path segment after the name.
        assert!(get_parameter_name("#/components/parameters/Limit/extra").is_err());
        // No path separator at all.
        assert!(get_parameter_name("Limit").is_err());
    }

    #[test]
    fn test_schema_reference_round_trip() {
        for reference in [
            "#/components/schemas/Pet",
            "#/components/schemas/Pet/properties/name",
        ] {
            assert_eq!(SchemaReference::from_str(reference).to_string(), reference);
        }
    }
}