1use std::collections::HashSet;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use crate::{OpenAPI, Parameter, RequestBody, Response, Schema};
5
/// A parsed `$ref` that points at a whole schema or at one property of a schema.
///
/// [`SchemaReference::from_str`] produces it, and the [`std::fmt::Display`] impl turns it
/// back into its `#/components/schemas/...` string form.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaReference {
    /// `#/components/schemas/{schema}`
    Schema {
        schema: String,
    },
    /// `#/components/schemas/{schema}/properties/{property}`
    Property {
        schema: String,
        property: String,
    },
}

impl SchemaReference {
    /// Parses a schema reference such as `#/components/schemas/Pet` or
    /// `#/components/schemas/Pet/properties/name`.
    ///
    /// Only the trailing segments are examined. The last `/`-separated segment is the name.
    /// The segment before it must be `schemas` (whole schema) or `properties`; in the
    /// `properties` case, the segment before that names the owning schema. Leading segments
    /// such as `#/components` are not validated.
    ///
    /// # Panics
    ///
    /// Panics with `Unknown reference: {reference}` when the reference has neither shape,
    /// including when it has too few segments.
    pub fn from_str(reference: &str) -> Self {
        let mut segments = reference.rsplit('/');
        // `rsplit` always yields at least one segment (the whole input when there is no '/').
        let name = segments.next().unwrap_or(reference);
        // Pull both preceding segments up front. Malformed input then reaches the single
        // descriptive panic below instead of a bare `unwrap` on `None`.
        match (segments.next(), segments.next()) {
            (Some("schemas"), _) => Self::Schema {
                schema: name.to_string(),
            },
            (Some("properties"), Some(schema)) => Self::Property {
                schema: schema.to_string(),
                property: name.to_string(),
            },
            _ => panic!("Unknown reference: {}", reference),
        }
    }
}

/// Renders the reference as `#/components/schemas/{schema}` or
/// `#/components/schemas/{schema}/properties/{property}`.
impl std::fmt::Display for SchemaReference {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SchemaReference::Schema { schema } => write!(f, "#/components/schemas/{}", schema),
            SchemaReference::Property { schema, property } => {
                write!(f, "#/components/schemas/{}/properties/{}", schema, property)
            }
        }
    }
}
49
50
51pub type ReferenceOr<T> = RefOr<T>;
53pub type RefOr<T> = Ref<T>;
54
55#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
56#[serde(untagged)]
57pub enum Ref<T> {
58 Reference {
59 #[serde(rename = "$ref")]
60 reference: String,
61 },
62 Item(T),
63}
64
65impl<T> Ref<T> {
66 pub fn ref_(r: &str) -> Self {
67 Ref::Reference {
68 reference: r.to_owned(),
69 }
70 }
71 pub fn schema_ref(r: &str) -> Self {
72 Ref::Reference {
73 reference: format!("#/components/schemas/{}", r),
74 }
75 }
76
77 pub fn boxed(self) -> Box<Ref<T>> {
78 Box::new(self)
79 }
80
81 pub fn into_item(self) -> Option<T> {
97 match self {
98 RefOr::Reference { .. } => None,
99 RefOr::Item(i) => Some(i),
100 }
101 }
102
103 pub fn as_item(&self) -> Option<&T> {
119 match self {
120 RefOr::Reference { .. } => None,
121 RefOr::Item(i) => Some(i),
122 }
123 }
124
125 pub fn as_ref_str(&self) -> Option<&str> {
126 match self {
127 RefOr::Reference { reference } => Some(reference),
128 RefOr::Item(_) => None,
129 }
130 }
131
132 pub fn as_mut(&mut self) -> Option<&mut T> {
133 match self {
134 RefOr::Reference { .. } => None,
135 RefOr::Item(i) => Some(i),
136 }
137 }
138
139 pub fn to_mut(&mut self) -> &mut T {
140 self.as_mut().expect("Not an item")
141 }
142}
143
144fn resolve_helper<'a>(reference: &str, spec: &'a OpenAPI, seen: &mut HashSet<String>) -> &'a Schema {
145 if seen.contains(reference) {
146 panic!("Circular reference: {}", reference);
147 }
148 seen.insert(reference.to_string());
149 let reference = SchemaReference::from_str(&reference);
150 match &reference {
151 SchemaReference::Schema { ref schema } => {
152 let schema_ref = spec.schemas.get(schema)
153 .expect(&format!("Schema {} not found in OpenAPI spec.", schema));
154 match schema_ref {
157 RefOr::Reference { reference } => {
158 resolve_helper(&reference, spec, seen)
159 }
160 RefOr::Item(s) => s
161 }
162 }
163 SchemaReference::Property { schema: schema_name, property } => {
164 let schema = spec.schemas.get(schema_name)
165 .expect(&format!("Schema {} not found in OpenAPI spec.", schema_name))
166 .as_item()
167 .expect(&format!("The schema {} was used in a reference, but that schema is itself a reference to another schema.", schema_name));
168 let prop_schema = schema
169 .properties()
170 .get(property)
171 .expect(&format!("Schema {} does not have property {}.", schema_name, property));
172 prop_schema.resolve(spec)
173 }
174 }
175}
176
177impl RefOr<Schema> {
178 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> &'a Schema {
179 match self {
180 RefOr::Reference { reference } => {
181 resolve_helper(reference, spec, &mut HashSet::new())
182 }
183 RefOr::Item(schema) => schema,
184 }
185 }
186}
187
188impl<T> From<T> for RefOr<T> {
189 fn from(item: T) -> Self {
190 RefOr::Item(item)
191 }
192}
193
194impl RefOr<Parameter> {
195 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> Result<&'a Parameter> {
196 match self {
197 RefOr::Reference { reference } => {
198 let name = get_parameter_name(&reference)?;
199 spec.parameters.get(name)
200 .ok_or(anyhow!("{} not found in OpenAPI spec.", reference))?
201 .as_item()
202 .ok_or(anyhow!("{} is circular.", reference))
203 }
204 RefOr::Item(parameter) => Ok(parameter),
205 }
206 }
207}
208
209
210impl RefOr<Response> {
211 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> Result<&'a Response> {
212 match self {
213 RefOr::Reference { reference } => {
214 let name = get_response_name(&reference)?;
215 spec.responses.get(name)
216 .ok_or(anyhow!("{} not found in OpenAPI spec.", reference))?
217 .as_item()
218 .ok_or(anyhow!("{} is circular.", reference))
219 }
220 RefOr::Item(response) => Ok(response),
221 }
222 }
223}
224
225impl Ref<RequestBody> {
226 pub fn resolve<'a>(&'a self, spec: &'a OpenAPI) -> Result<&'a RequestBody> {
227 match self {
228 Ref::Reference { reference } => {
229 let name = get_request_body_name(&reference)?;
230 spec.request_bodies.get(name)
231 .ok_or(anyhow!("{} not found in OpenAPI spec.", reference))?
232 .as_item()
233 .ok_or(anyhow!("{} is circular.", reference))
234 }
235 Ref::Item(request_body) => Ok(request_body),
236 }
237 }
238}
239
240impl<T: Default> Default for RefOr<T> {
241 fn default() -> Self {
242 Ref::Item(T::default())
243 }
244}
245
246fn parse_reference<'a>(reference: &'a str, group: &str) -> Result<&'a str> {
247 let mut parts = reference.rsplitn(2, '/');
248 let name = parts.next();
249 name.filter(|_| matches!(parts.next(), Some(x) if format!("#/components/{group}") == x))
250 .ok_or(anyhow!("Invalid {} reference: {}", group, reference))
251}
252
253
254fn get_response_name(reference: &str) -> Result<&str> {
255 parse_reference(reference, "responses")
256}
257
258
259fn get_request_body_name(reference: &str) -> Result<&str> {
260 parse_reference(reference, "requestBodies")
261}
262
263fn get_parameter_name(reference: &str) -> Result<&str> {
264 parse_reference(reference, "parameters")
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_request_body_name() {
        assert!(matches!(get_request_body_name("#/components/requestBodies/Foo"), Ok("Foo")));
        assert!(get_request_body_name("#/components/schemas/Foo").is_err());
    }

    #[test]
    fn test_get_parameter_and_response_names() {
        assert!(matches!(get_parameter_name("#/components/parameters/Limit"), Ok("Limit")));
        assert!(matches!(get_response_name("#/components/responses/NotFound"), Ok("NotFound")));
        // Wrong component group, and a bare name with no path at all.
        assert!(get_parameter_name("#/components/responses/Limit").is_err());
        assert!(get_response_name("NotFound").is_err());
    }

    #[test]
    fn test_schema_reference_round_trip() {
        let raw = "#/components/schemas/Pet";
        let parsed = SchemaReference::from_str(raw);
        assert!(matches!(&parsed, SchemaReference::Schema { schema } if schema == "Pet"));
        assert_eq!(parsed.to_string(), raw);

        let raw = "#/components/schemas/Pet/properties/name";
        let parsed = SchemaReference::from_str(raw);
        assert!(matches!(
            &parsed,
            SchemaReference::Property { schema, property } if schema == "Pet" && property == "name"
        ));
        assert_eq!(parsed.to_string(), raw);
    }

    #[test]
    fn test_ref_accessors() {
        let reference: RefOr<u32> = Ref::schema_ref("Pet");
        assert_eq!(reference.as_ref_str(), Some("#/components/schemas/Pet"));
        assert_eq!(reference.as_item(), None);

        let mut item: RefOr<u32> = RefOr::from(7u32);
        assert_eq!(item.as_ref_str(), None);
        *item.to_mut() += 1;
        assert_eq!(item.into_item(), Some(8));
    }
}