1use crate::{
7 ai_response::{expand_prompt_template, AiResponseConfig, RequestContext},
8 OpenApiSpec, Result,
9};
10use async_trait::async_trait;
11use chrono;
12use openapiv3::{Operation, ReferenceOr, Response, Responses, Schema};
13use rand::{rng, Rng};
14use serde_json::Value;
15use std::collections::HashMap;
16use uuid;
17
18#[async_trait]
23pub trait AiGenerator: Send + Sync {
24 async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value>;
33}
34
35pub struct ResponseGenerator;
37
38impl ResponseGenerator {
39 pub async fn generate_ai_response(
52 ai_config: &AiResponseConfig,
53 context: &RequestContext,
54 generator: Option<&dyn AiGenerator>,
55 ) -> Result<Value> {
56 let prompt_template = ai_config
58 .prompt
59 .as_ref()
60 .ok_or_else(|| crate::Error::generic("AI prompt is required"))?;
61
62 let expanded_prompt = expand_prompt_template(prompt_template, context);
63
64 tracing::info!("AI response generation requested with prompt: {}", expanded_prompt);
65
66 if let Some(gen) = generator {
68 tracing::debug!("Using provided AI generator for response");
69 return gen.generate(&expanded_prompt, ai_config).await;
70 }
71
72 tracing::warn!("No AI generator provided, returning placeholder response");
74 Ok(serde_json::json!({
75 "ai_response": "AI generation placeholder",
76 "note": "This endpoint is configured for AI-assisted responses, but no AI generator was provided",
77 "expanded_prompt": expanded_prompt,
78 "mode": format!("{:?}", ai_config.mode),
79 "temperature": ai_config.temperature,
80 "implementation_note": "Pass an AiGenerator implementation to ResponseGenerator::generate_ai_response to enable actual AI generation"
81 }))
82 }
83
84 pub fn generate_response(
86 spec: &OpenApiSpec,
87 operation: &Operation,
88 status_code: u16,
89 content_type: Option<&str>,
90 ) -> Result<Value> {
91 Self::generate_response_with_expansion(spec, operation, status_code, content_type, true)
92 }
93
94 pub fn generate_response_with_expansion(
96 spec: &OpenApiSpec,
97 operation: &Operation,
98 status_code: u16,
99 content_type: Option<&str>,
100 expand_tokens: bool,
101 ) -> Result<Value> {
102 let response = Self::find_response_for_status(&operation.responses, status_code);
104
105 match response {
106 Some(response_ref) => {
107 match response_ref {
108 ReferenceOr::Item(response) => {
109 Self::generate_from_response(spec, response, content_type, expand_tokens)
110 }
111 ReferenceOr::Reference { reference } => {
112 if let Some(resolved_response) = spec.get_response(reference) {
114 Self::generate_from_response(
115 spec,
116 resolved_response,
117 content_type,
118 expand_tokens,
119 )
120 } else {
121 Ok(Value::Object(serde_json::Map::new()))
123 }
124 }
125 }
126 }
127 None => {
128 Ok(Value::Object(serde_json::Map::new()))
130 }
131 }
132 }
133
134 fn find_response_for_status(
136 responses: &Responses,
137 status_code: u16,
138 ) -> Option<&ReferenceOr<Response>> {
139 if let Some(response) = responses.responses.get(&openapiv3::StatusCode::Code(status_code)) {
141 return Some(response);
142 }
143
144 if let Some(default_response) = &responses.default {
146 return Some(default_response);
147 }
148
149 None
150 }
151
152 fn generate_from_response(
154 spec: &OpenApiSpec,
155 response: &Response,
156 content_type: Option<&str>,
157 expand_tokens: bool,
158 ) -> Result<Value> {
159 if let Some(content_type) = content_type {
161 if let Some(media_type) = response.content.get(content_type) {
162 return Self::generate_from_media_type(spec, media_type, expand_tokens);
163 }
164 }
165
166 let preferred_types = ["application/json", "application/xml", "text/plain"];
168
169 for content_type in &preferred_types {
170 if let Some(media_type) = response.content.get(*content_type) {
171 return Self::generate_from_media_type(spec, media_type, expand_tokens);
172 }
173 }
174
175 if let Some((_, media_type)) = response.content.iter().next() {
177 return Self::generate_from_media_type(spec, media_type, expand_tokens);
178 }
179
180 Ok(Value::Object(serde_json::Map::new()))
182 }
183
184 fn generate_from_media_type(
186 spec: &OpenApiSpec,
187 media_type: &openapiv3::MediaType,
188 expand_tokens: bool,
189 ) -> Result<Value> {
190 if let Some(example) = &media_type.example {
192 tracing::debug!("Using explicit example from media type: {:?}", example);
193 if expand_tokens {
195 let expanded_example = Self::expand_templates(example);
196 return Ok(expanded_example);
197 } else {
198 return Ok(example.clone());
199 }
200 }
201
202 if !media_type.examples.is_empty() {
204 if let Some((_, example_ref)) = media_type.examples.iter().next() {
205 match example_ref {
206 ReferenceOr::Item(example) => {
207 if let Some(value) = &example.value {
208 tracing::debug!("Using example from examples map: {:?}", value);
209 if expand_tokens {
210 return Ok(Self::expand_templates(value));
211 } else {
212 return Ok(value.clone());
213 }
214 }
215 }
216 ReferenceOr::Reference { reference } => {
217 if let Some(example) = spec.get_example(reference) {
219 if let Some(value) = &example.value {
220 tracing::debug!("Using resolved example reference: {:?}", value);
221 if expand_tokens {
222 return Ok(Self::expand_templates(value));
223 } else {
224 return Ok(value.clone());
225 }
226 }
227 } else {
228 tracing::warn!("Example reference '{}' not found", reference);
229 }
230 }
231 }
232 }
233 }
234
235 if let Some(schema_ref) = &media_type.schema {
237 Ok(Self::generate_example_from_schema_ref(spec, schema_ref))
238 } else {
239 Ok(Value::Object(serde_json::Map::new()))
240 }
241 }
242
243 fn generate_example_from_schema_ref(
244 spec: &OpenApiSpec,
245 schema_ref: &ReferenceOr<Schema>,
246 ) -> Value {
247 match schema_ref {
248 ReferenceOr::Item(schema) => Self::generate_example_from_schema(spec, schema),
249 ReferenceOr::Reference { reference } => spec
250 .get_schema(reference)
251 .map(|schema| Self::generate_example_from_schema(spec, &schema.schema))
252 .unwrap_or_else(|| Value::Object(serde_json::Map::new())),
253 }
254 }
255
256 fn generate_example_from_schema(spec: &OpenApiSpec, schema: &Schema) -> Value {
258 match &schema.schema_kind {
259 openapiv3::SchemaKind::Type(openapiv3::Type::String(_)) => {
260 Value::String("example string".to_string())
262 }
263 openapiv3::SchemaKind::Type(openapiv3::Type::Integer(_)) => Value::Number(42.into()),
264 openapiv3::SchemaKind::Type(openapiv3::Type::Number(_)) => {
265 Value::Number(serde_json::Number::from_f64(std::f64::consts::PI).unwrap())
266 }
267 openapiv3::SchemaKind::Type(openapiv3::Type::Boolean(_)) => Value::Bool(true),
268 openapiv3::SchemaKind::Type(openapiv3::Type::Object(obj)) => {
269 let mut map = serde_json::Map::new();
270 for (prop_name, prop_schema) in &obj.properties {
271 let value = match prop_schema {
272 ReferenceOr::Item(prop_schema) => {
273 Self::generate_example_from_schema(spec, prop_schema.as_ref())
274 }
275 ReferenceOr::Reference { reference } => spec
276 .get_schema(reference)
277 .map(|schema| Self::generate_example_from_schema(spec, &schema.schema))
278 .unwrap_or_else(|| Self::generate_example_for_property(prop_name)),
279 };
280 let value = match value {
281 Value::Null => Self::generate_example_for_property(prop_name),
282 Value::Object(ref obj) if obj.is_empty() => {
283 Self::generate_example_for_property(prop_name)
284 }
285 _ => value,
286 };
287 map.insert(prop_name.clone(), value);
288 }
289 Value::Object(map)
290 }
291 openapiv3::SchemaKind::Type(openapiv3::Type::Array(arr)) => match &arr.items {
292 Some(item_schema) => {
293 let example_item = match item_schema {
294 ReferenceOr::Item(item_schema) => {
295 Self::generate_example_from_schema(spec, item_schema.as_ref())
296 }
297 ReferenceOr::Reference { reference } => spec
298 .get_schema(reference)
299 .map(|schema| Self::generate_example_from_schema(spec, &schema.schema))
300 .unwrap_or_else(|| Value::Object(serde_json::Map::new())),
301 };
302 Value::Array(vec![example_item])
303 }
304 None => Value::Array(vec![Value::String("item".to_string())]),
305 },
306 _ => Value::Object(serde_json::Map::new()),
307 }
308 }
309
310 fn generate_example_for_property(prop_name: &str) -> Value {
312 let prop_lower = prop_name.to_lowercase();
313
314 if prop_lower.contains("id") || prop_lower.contains("uuid") {
316 Value::String(uuid::Uuid::new_v4().to_string())
317 } else if prop_lower.contains("email") {
318 Value::String(format!("user{}@example.com", rng().random_range(1000..=9999)))
319 } else if prop_lower.contains("name") || prop_lower.contains("title") {
320 let names = ["John Doe", "Jane Smith", "Bob Johnson", "Alice Brown"];
321 Value::String(names[rng().random_range(0..names.len())].to_string())
322 } else if prop_lower.contains("phone") || prop_lower.contains("mobile") {
323 Value::String(format!("+1-555-{:04}", rng().random_range(1000..=9999)))
324 } else if prop_lower.contains("address") || prop_lower.contains("street") {
325 let streets = ["123 Main St", "456 Oak Ave", "789 Pine Rd", "321 Elm St"];
326 Value::String(streets[rng().random_range(0..streets.len())].to_string())
327 } else if prop_lower.contains("city") {
328 let cities = ["New York", "London", "Tokyo", "Paris", "Sydney"];
329 Value::String(cities[rng().random_range(0..cities.len())].to_string())
330 } else if prop_lower.contains("country") {
331 let countries = ["USA", "UK", "Japan", "France", "Australia"];
332 Value::String(countries[rng().random_range(0..countries.len())].to_string())
333 } else if prop_lower.contains("company") || prop_lower.contains("organization") {
334 let companies = ["Acme Corp", "Tech Solutions", "Global Inc", "Innovate Ltd"];
335 Value::String(companies[rng().random_range(0..companies.len())].to_string())
336 } else if prop_lower.contains("url") || prop_lower.contains("website") {
337 Value::String("https://example.com".to_string())
338 } else if prop_lower.contains("age") {
339 Value::Number((18 + rng().random_range(0..60)).into())
340 } else if prop_lower.contains("count") || prop_lower.contains("quantity") {
341 Value::Number((1 + rng().random_range(0..100)).into())
342 } else if prop_lower.contains("price")
343 || prop_lower.contains("amount")
344 || prop_lower.contains("cost")
345 {
346 Value::Number(
347 serde_json::Number::from_f64(
348 (rng().random::<f64>() * 1000.0 * 100.0).round() / 100.0,
349 )
350 .unwrap(),
351 )
352 } else if prop_lower.contains("active")
353 || prop_lower.contains("enabled")
354 || prop_lower.contains("is_")
355 {
356 Value::Bool(rng().random_bool(0.5))
357 } else if prop_lower.contains("date") || prop_lower.contains("time") {
358 Value::String(chrono::Utc::now().to_rfc3339())
359 } else if prop_lower.contains("description") || prop_lower.contains("comment") {
360 Value::String("This is a sample description text.".to_string())
361 } else {
362 Value::String(format!("example {}", prop_name))
363 }
364 }
365
366 pub fn generate_from_examples(
368 response: &Response,
369 content_type: Option<&str>,
370 ) -> Result<Option<Value>> {
371 use openapiv3::ReferenceOr;
372
373 if let Some(content_type) = content_type {
375 if let Some(media_type) = response.content.get(content_type) {
376 if let Some(example) = &media_type.example {
378 return Ok(Some(example.clone()));
379 }
380
381 for (_, example_ref) in &media_type.examples {
383 if let ReferenceOr::Item(example) = example_ref {
384 if let Some(value) = &example.value {
385 return Ok(Some(value.clone()));
386 }
387 }
388 }
390 }
391 }
392
393 for (_, media_type) in &response.content {
395 if let Some(example) = &media_type.example {
397 return Ok(Some(example.clone()));
398 }
399
400 for (_, example_ref) in &media_type.examples {
402 if let ReferenceOr::Item(example) = example_ref {
403 if let Some(value) = &example.value {
404 return Ok(Some(value.clone()));
405 }
406 }
407 }
409 }
410
411 Ok(None)
412 }
413
414 fn expand_templates(value: &Value) -> Value {
416 match value {
417 Value::String(s) => {
418 let expanded = s
419 .replace("{{now}}", &chrono::Utc::now().to_rfc3339())
420 .replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
421 Value::String(expanded)
422 }
423 Value::Object(map) => {
424 let mut new_map = serde_json::Map::new();
425 for (key, val) in map {
426 new_map.insert(key.clone(), Self::expand_templates(val));
427 }
428 Value::Object(new_map)
429 }
430 Value::Array(arr) => {
431 let new_arr: Vec<Value> = arr.iter().map(Self::expand_templates).collect();
432 Value::Array(new_arr)
433 }
434 _ => value.clone(),
435 }
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use openapiv3::ReferenceOr;

    /// A response schema that `$ref`s a schema which itself `$ref`s another must
    /// be expanded all the way down into nested example objects.
    #[test]
    fn generates_example_using_referenced_schemas() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths:
  /apiaries:
    get:
      responses:
        '200':
          description: ok
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Apiary'
components:
  schemas:
    Apiary:
      type: object
      properties:
        id:
          type: string
        hive:
          $ref: '#/components/schemas/Hive'
    Hive:
      type: object
      properties:
        name:
          type: string
        active:
          type: boolean
        "#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("load spec");

        let apiaries = spec.spec.paths.paths.get("/apiaries");
        let path_item = apiaries.and_then(ReferenceOr::as_item).expect("path item");
        let get_op = path_item.get.as_ref().expect("GET operation");

        let body = ResponseGenerator::generate_response(
            &spec,
            get_op,
            200,
            Some("application/json"),
        )
        .expect("generate response");

        let apiary = body.as_object().expect("response object");
        assert!(apiary.contains_key("id"));

        let hive = apiary
            .get("hive")
            .and_then(Value::as_object)
            .expect("hive object");
        for field in ["name", "active"] {
            assert!(hive.contains_key(field));
        }
    }
}
500
501#[derive(Debug, Clone)]
503pub struct MockResponse {
504 pub status_code: u16,
506 pub headers: HashMap<String, String>,
508 pub body: Option<Value>,
510}
511
512impl MockResponse {
513 pub fn new(status_code: u16) -> Self {
515 Self {
516 status_code,
517 headers: HashMap::new(),
518 body: None,
519 }
520 }
521
522 pub fn with_header(mut self, name: String, value: String) -> Self {
524 self.headers.insert(name, value);
525 self
526 }
527
528 pub fn with_body(mut self, body: Value) -> Self {
530 self.body = Some(body);
531 self
532 }
533}
534
/// One security requirement: a security scheme identifier plus the scopes it demands.
#[derive(Debug, Clone)]
pub struct OpenApiSecurityRequirement {
    /// Scheme identifier (presumably the key under `components.securitySchemes`;
    /// confirm against callers).
    pub scheme: String,
    /// Scopes required for the scheme; may be empty.
    pub scopes: Vec<String>,
}

impl OpenApiSecurityRequirement {
    /// Builds a requirement from a scheme identifier and its scopes, storing both as given.
    pub fn new(scheme: String, scopes: Vec<String>) -> Self {
        Self {
            scopes,
            scheme,
        }
    }
}
550
551#[derive(Debug, Clone)]
553pub struct OpenApiOperation {
554 pub method: String,
556 pub path: String,
558 pub operation: openapiv3::Operation,
560}
561
562impl OpenApiOperation {
563 pub fn new(method: String, path: String, operation: openapiv3::Operation) -> Self {
565 Self {
566 method,
567 path,
568 operation,
569 }
570 }
571}