1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4use std::sync::Arc;
5use reqwest::{Client, Method, Response};
6use serde_json::Value as JsonValue;
7use url::Url;
8use dashmap::DashMap;
9use tokio::time::sleep;
10
11use crate::connectors::{Connector, ConnectorInitConfig, ConnectorCapabilities};
12use crate::utils::{
13 types::{
14 ConnectorType, ConnectorQuery, QueryResult, Schema, ColumnMetadata, DataType,
15 Row, Value, PredicateOperator, PredicateValue
16 },
17 error::{ConnectorError, NirvResult},
18};
19
/// Authentication scheme applied to outgoing HTTP requests.
///
/// `Debug` is implemented by hand so that credentials never reach logs or
/// panic messages: secret fields are rendered as `"<redacted>"`.
#[derive(Clone)]
pub enum AuthConfig {
    /// Send requests without credentials.
    None,
    /// Send `key` in the request header named `header`.
    ApiKey { header: String, key: String },
    /// Send `token` as a bearer credential.
    Bearer { token: String },
    /// Send `username` / `password` as HTTP basic credentials.
    Basic { username: String, password: String },
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder shown instead of any secret value.
        const REDACTED: &str = "<redacted>";

        match self {
            AuthConfig::None => f.write_str("None"),
            AuthConfig::ApiKey { header, .. } => f
                .debug_struct("ApiKey")
                .field("header", header)
                .field("key", &REDACTED)
                .finish(),
            AuthConfig::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &REDACTED)
                .finish(),
            AuthConfig::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
        }
    }
}
28
29#[derive(Debug, Clone)]
31struct CacheEntry {
32 data: JsonValue,
33 timestamp: Instant,
34 ttl: Duration,
35}
36
37impl CacheEntry {
38 fn new(data: JsonValue, ttl: Duration) -> Self {
39 Self {
40 data,
41 timestamp: Instant::now(),
42 ttl,
43 }
44 }
45
46 fn is_expired(&self) -> bool {
47 self.timestamp.elapsed() > self.ttl
48 }
49}
50
/// Parameters of the request throttle.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Sustained request rate, in requests per second.
    pub requests_per_second: f64,
    /// Maximum number of requests that may be issued back to back.
    pub burst_size: u32,
}

impl Default for RateLimitConfig {
    /// Ten requests per second with a burst of ten.
    fn default() -> Self {
        const DEFAULT_RPS: f64 = 10.0;
        const DEFAULT_BURST: u32 = 10;

        RateLimitConfig {
            requests_per_second: DEFAULT_RPS,
            burst_size: DEFAULT_BURST,
        }
    }
}
66
67#[derive(Debug)]
69struct RateLimiter {
70 config: RateLimitConfig,
71 last_request: Option<Instant>,
72 tokens: f64,
73}
74
75impl RateLimiter {
76 fn new(config: RateLimitConfig) -> Self {
77 Self {
78 tokens: config.burst_size as f64,
79 config,
80 last_request: None,
81 }
82 }
83
84 async fn acquire(&mut self) -> NirvResult<()> {
85 let now = Instant::now();
86
87 if let Some(last) = self.last_request {
89 let elapsed = now.duration_since(last).as_secs_f64();
90 self.tokens = (self.tokens + elapsed * self.config.requests_per_second)
91 .min(self.config.burst_size as f64);
92 }
93
94 if self.tokens >= 1.0 {
95 self.tokens -= 1.0;
96 self.last_request = Some(now);
97 Ok(())
98 } else {
99 let wait_time = Duration::from_secs_f64(1.0 / self.config.requests_per_second);
101 sleep(wait_time).await;
102 self.tokens = (self.config.burst_size as f64 - 1.0).max(0.0);
103 self.last_request = Some(Instant::now());
104 Ok(())
105 }
106 }
107}
108
109pub struct RestConnector {
111 client: Option<Client>,
112 base_url: Option<Url>,
113 auth_config: AuthConfig,
114 cache: Arc<DashMap<String, CacheEntry>>,
115 cache_ttl: Duration,
116 rate_limiter: Option<RateLimiter>,
117 connected: bool,
118 endpoint_mappings: HashMap<String, EndpointMapping>,
119}
120
121#[derive(Debug, Clone)]
123pub struct EndpointMapping {
124 pub path: String,
125 pub method: Method,
126 pub query_params: HashMap<String, String>,
127 pub response_path: Option<String>, pub id_field: Option<String>, }
130
131impl RestConnector {
132 pub fn new() -> Self {
134 Self {
135 client: None,
136 base_url: None,
137 auth_config: AuthConfig::None,
138 cache: Arc::new(DashMap::new()),
139 cache_ttl: Duration::from_secs(300), rate_limiter: None,
141 connected: false,
142 endpoint_mappings: HashMap::new(),
143 }
144 }
145
146 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
148 self.auth_config = auth;
149 self
150 }
151
152 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
154 self.cache_ttl = ttl;
155 self
156 }
157
158 pub fn with_rate_limit(mut self, config: RateLimitConfig) -> Self {
160 self.rate_limiter = Some(RateLimiter::new(config));
161 self
162 }
163
164 pub fn add_endpoint_mapping(&mut self, name: String, mapping: EndpointMapping) {
166 self.endpoint_mappings.insert(name, mapping);
167 }
168
169 async fn build_request(&self, method: Method, url: &Url) -> NirvResult<reqwest::RequestBuilder> {
171 let client = self.client.as_ref()
172 .ok_or_else(|| ConnectorError::ConnectionFailed("Not connected".to_string()))?;
173
174 let mut request = client.request(method, url.clone());
175
176 match &self.auth_config {
178 AuthConfig::None => {},
179 AuthConfig::ApiKey { header, key } => {
180 request = request.header(header, key);
181 },
182 AuthConfig::Bearer { token } => {
183 request = request.bearer_auth(token);
184 },
185 AuthConfig::Basic { username, password } => {
186 request = request.basic_auth(username, Some(password));
187 },
188 }
189
190 Ok(request)
191 }
192
193 async fn execute_request(&mut self, method: Method, url: &Url) -> NirvResult<Response> {
195 if let Some(ref mut limiter) = self.rate_limiter {
197 limiter.acquire().await?;
198 }
199
200 let request = self.build_request(method, url).await?;
201
202 let response = request.send().await
203 .map_err(|e| ConnectorError::QueryExecutionFailed(
204 format!("HTTP request failed: {}", e)
205 ))?;
206
207 if !response.status().is_success() {
208 return Err(ConnectorError::QueryExecutionFailed(
209 format!("HTTP request failed with status: {}", response.status())
210 ).into());
211 }
212
213 Ok(response)
214 }
215
216 async fn get_cached_or_fetch(&mut self, cache_key: &str, url: &Url, method: Method) -> NirvResult<JsonValue> {
218 if let Some(entry) = self.cache.get(cache_key) {
220 if !entry.is_expired() {
221 return Ok(entry.data.clone());
222 }
223 }
224
225 let response = self.execute_request(method, url).await?;
227 let json_data: JsonValue = response.json().await
228 .map_err(|e| ConnectorError::QueryExecutionFailed(
229 format!("Failed to parse JSON response: {}", e)
230 ))?;
231
232 let entry = CacheEntry::new(json_data.clone(), self.cache_ttl);
234 self.cache.insert(cache_key.to_string(), entry);
235
236 Ok(json_data)
237 }
238
239 fn extract_data_array(&self, json: &JsonValue, path: Option<&str>) -> NirvResult<Vec<JsonValue>> {
241 match path {
242 Some(json_path) => {
243 let parts: Vec<&str> = json_path.split('.').collect();
245 let mut current = json;
246
247 for part in parts {
248 if part.is_empty() {
249 continue;
250 }
251
252 current = current.get(part)
253 .ok_or_else(|| ConnectorError::QueryExecutionFailed(
254 format!("JSONPath '{}' not found in response", json_path)
255 ))?;
256 }
257
258 match current {
259 JsonValue::Array(arr) => Ok(arr.clone()),
260 _ => Err(ConnectorError::QueryExecutionFailed(
261 format!("JSONPath '{}' does not point to an array", json_path)
262 ).into()),
263 }
264 },
265 None => {
266 match json {
267 JsonValue::Array(arr) => Ok(arr.clone()),
268 JsonValue::Object(_) => Ok(vec![json.clone()]),
269 _ => Err(ConnectorError::QueryExecutionFailed(
270 "Response is not an array or object".to_string()
271 ).into()),
272 }
273 }
274 }
275 }
276
277 fn json_to_row(&self, json_obj: &JsonValue, columns: &[ColumnMetadata]) -> Row {
279 let mut values = Vec::new();
280
281 for column in columns {
282 let value = if let JsonValue::Object(obj) = json_obj {
283 obj.get(&column.name)
284 .map(|v| self.json_value_to_value(v))
285 .unwrap_or(Value::Null)
286 } else {
287 Value::Null
288 };
289 values.push(value);
290 }
291
292 Row::new(values)
293 }
294
295 fn json_value_to_value(&self, json_val: &JsonValue) -> Value {
297 match json_val {
298 JsonValue::Null => Value::Null,
299 JsonValue::Bool(b) => Value::Boolean(*b),
300 JsonValue::Number(n) => {
301 if let Some(i) = n.as_i64() {
302 Value::Integer(i)
303 } else if let Some(f) = n.as_f64() {
304 Value::Float(f)
305 } else {
306 Value::Text(n.to_string())
307 }
308 },
309 JsonValue::String(s) => Value::Text(s.clone()),
310 JsonValue::Array(_) | JsonValue::Object(_) => {
311 Value::Json(json_val.to_string())
312 },
313 }
314 }
315
316 fn infer_schema_from_json(&self, data: &[JsonValue], object_name: &str) -> Schema {
318 let mut columns = Vec::new();
319
320 if let Some(first_obj) = data.first() {
321 if let JsonValue::Object(obj) = first_obj {
322 for (key, value) in obj {
323 let data_type = match value {
324 JsonValue::Null => DataType::Text,
325 JsonValue::Bool(_) => DataType::Boolean,
326 JsonValue::Number(n) => {
327 if n.is_i64() {
328 DataType::Integer
329 } else {
330 DataType::Float
331 }
332 },
333 JsonValue::String(_) => DataType::Text,
334 JsonValue::Array(_) | JsonValue::Object(_) => DataType::Json,
335 };
336
337 columns.push(ColumnMetadata {
338 name: key.clone(),
339 data_type,
340 nullable: true,
341 });
342 }
343 }
344 }
345
346 Schema {
347 name: object_name.to_string(),
348 columns,
349 primary_key: None,
350 indexes: Vec::new(),
351 }
352 }
353
354 fn apply_predicates(&self, data: Vec<JsonValue>, predicates: &[crate::utils::types::Predicate]) -> Vec<JsonValue> {
356 if predicates.is_empty() {
357 return data;
358 }
359
360 data.into_iter()
361 .filter(|item| {
362 if let JsonValue::Object(obj) = item {
363 predicates.iter().all(|predicate| {
364 if let Some(field_value) = obj.get(&predicate.column) {
365 let value = self.json_value_to_value(field_value);
366 self.evaluate_predicate(&value, &predicate.operator, &predicate.value)
367 } else {
368 false
369 }
370 })
371 } else {
372 false
373 }
374 })
375 .collect()
376 }
377
378 fn evaluate_predicate(&self, value: &Value, operator: &PredicateOperator, predicate_value: &PredicateValue) -> bool {
380 match operator {
381 PredicateOperator::Equal => self.values_equal(value, predicate_value),
382 PredicateOperator::NotEqual => !self.values_equal(value, predicate_value),
383 PredicateOperator::GreaterThan => self.value_greater_than(value, predicate_value),
384 PredicateOperator::GreaterThanOrEqual => {
385 self.value_greater_than(value, predicate_value) || self.values_equal(value, predicate_value)
386 },
387 PredicateOperator::LessThan => self.value_less_than(value, predicate_value),
388 PredicateOperator::LessThanOrEqual => {
389 self.value_less_than(value, predicate_value) || self.values_equal(value, predicate_value)
390 },
391 PredicateOperator::Like => self.value_like(value, predicate_value),
392 PredicateOperator::In => self.value_in(value, predicate_value),
393 PredicateOperator::IsNull => matches!(value, Value::Null),
394 PredicateOperator::IsNotNull => !matches!(value, Value::Null),
395 }
396 }
397
398 fn values_equal(&self, value: &Value, predicate_value: &PredicateValue) -> bool {
400 match (value, predicate_value) {
401 (Value::Text(v), PredicateValue::String(p)) => v == p,
402 (Value::Integer(v), PredicateValue::Integer(p)) => v == p,
403 (Value::Float(v), PredicateValue::Number(p)) => (v - p).abs() < f64::EPSILON,
404 (Value::Boolean(v), PredicateValue::Boolean(p)) => v == p,
405 (Value::Null, PredicateValue::Null) => true,
406 (Value::Integer(v), PredicateValue::Number(p)) => (*v as f64 - p).abs() < f64::EPSILON,
408 (Value::Float(v), PredicateValue::Integer(p)) => (v - *p as f64).abs() < f64::EPSILON,
409 _ => false,
410 }
411 }
412
413 fn value_greater_than(&self, value: &Value, predicate_value: &PredicateValue) -> bool {
415 match (value, predicate_value) {
416 (Value::Integer(v), PredicateValue::Integer(p)) => v > p,
417 (Value::Float(v), PredicateValue::Number(p)) => v > p,
418 (Value::Integer(v), PredicateValue::Number(p)) => (*v as f64) > *p,
419 (Value::Float(v), PredicateValue::Integer(p)) => *v > (*p as f64),
420 (Value::Text(v), PredicateValue::String(p)) => v > p,
421 _ => false,
422 }
423 }
424
425 fn value_less_than(&self, value: &Value, predicate_value: &PredicateValue) -> bool {
427 match (value, predicate_value) {
428 (Value::Integer(v), PredicateValue::Integer(p)) => v < p,
429 (Value::Float(v), PredicateValue::Number(p)) => v < p,
430 (Value::Integer(v), PredicateValue::Number(p)) => (*v as f64) < *p,
431 (Value::Float(v), PredicateValue::Integer(p)) => *v < (*p as f64),
432 (Value::Text(v), PredicateValue::String(p)) => v < p,
433 _ => false,
434 }
435 }
436
437 fn value_like(&self, value: &Value, predicate_value: &PredicateValue) -> bool {
439 match (value, predicate_value) {
440 (Value::Text(v), PredicateValue::String(pattern)) => {
441 let regex_pattern = pattern
442 .replace('%', ".*")
443 .replace('_', ".");
444
445 if let Ok(regex) = regex::Regex::new(&format!("^{}$", regex_pattern)) {
446 regex.is_match(v)
447 } else {
448 false
449 }
450 },
451 _ => false,
452 }
453 }
454
455 fn value_in(&self, value: &Value, predicate_value: &PredicateValue) -> bool {
457 match predicate_value {
458 PredicateValue::List(list) => {
459 list.iter().any(|item| self.values_equal(value, item))
460 },
461 _ => false,
462 }
463 }
464}
465
466impl Default for RestConnector {
467 fn default() -> Self {
468 Self::new()
469 }
470}
471
472#[async_trait]
473impl Connector for RestConnector {
474 async fn connect(&mut self, config: ConnectorInitConfig) -> NirvResult<()> {
475 let base_url_str = config.connection_params.get("base_url")
476 .ok_or_else(|| ConnectorError::ConnectionFailed(
477 "base_url parameter is required".to_string()
478 ))?;
479
480 let base_url = Url::parse(base_url_str)
481 .map_err(|e| ConnectorError::ConnectionFailed(
482 format!("Invalid base URL: {}", e)
483 ))?;
484
485 if let Some(auth_type) = config.connection_params.get("auth_type") {
487 self.auth_config = match auth_type.as_str() {
488 "api_key" => {
489 let header = config.connection_params.get("auth_header")
490 .unwrap_or(&"X-API-Key".to_string()).clone();
491 let key = config.connection_params.get("api_key")
492 .ok_or_else(|| ConnectorError::ConnectionFailed(
493 "api_key parameter is required for API key auth".to_string()
494 ))?.clone();
495 AuthConfig::ApiKey { header, key }
496 },
497 "bearer" => {
498 let token = config.connection_params.get("bearer_token")
499 .ok_or_else(|| ConnectorError::ConnectionFailed(
500 "bearer_token parameter is required for bearer auth".to_string()
501 ))?.clone();
502 AuthConfig::Bearer { token }
503 },
504 "basic" => {
505 let username = config.connection_params.get("username")
506 .ok_or_else(|| ConnectorError::ConnectionFailed(
507 "username parameter is required for basic auth".to_string()
508 ))?.clone();
509 let password = config.connection_params.get("password")
510 .ok_or_else(|| ConnectorError::ConnectionFailed(
511 "password parameter is required for basic auth".to_string()
512 ))?.clone();
513 AuthConfig::Basic { username, password }
514 },
515 "none" | _ => AuthConfig::None,
516 };
517 }
518
519 if let Some(cache_ttl_str) = config.connection_params.get("cache_ttl_seconds") {
521 if let Ok(ttl_seconds) = cache_ttl_str.parse::<u64>() {
522 self.cache_ttl = Duration::from_secs(ttl_seconds);
523 }
524 }
525
526 if let Some(rps_str) = config.connection_params.get("rate_limit_rps") {
528 if let Ok(rps) = rps_str.parse::<f64>() {
529 let burst_size = config.connection_params.get("rate_limit_burst")
530 .and_then(|s| s.parse::<u32>().ok())
531 .unwrap_or(10);
532
533 let rate_config = RateLimitConfig {
534 requests_per_second: rps,
535 burst_size,
536 };
537 self.rate_limiter = Some(RateLimiter::new(rate_config));
538 }
539 }
540
541 let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(30));
543 let client = Client::builder()
544 .timeout(timeout)
545 .build()
546 .map_err(|e| ConnectorError::ConnectionFailed(
547 format!("Failed to create HTTP client: {}", e)
548 ))?;
549
550 self.client = Some(client);
551 self.base_url = Some(base_url);
552 self.connected = true;
553
554 Ok(())
555 }
556
557 async fn execute_query(&self, query: ConnectorQuery) -> NirvResult<QueryResult> {
558 if !self.connected {
559 return Err(ConnectorError::ConnectionFailed("Not connected".to_string()).into());
560 }
561
562 if query.query.sources.is_empty() {
563 return Err(ConnectorError::QueryExecutionFailed(
564 "No data source specified in query".to_string()
565 ).into());
566 }
567
568 let source = &query.query.sources[0];
569 let endpoint_name = &source.identifier;
570
571 let mapping = self.endpoint_mappings.get(endpoint_name)
573 .ok_or_else(|| ConnectorError::QueryExecutionFailed(
574 format!("No endpoint mapping found for '{}'", endpoint_name)
575 ))?;
576
577 let base_url = self.base_url.as_ref()
578 .ok_or_else(|| ConnectorError::ConnectionFailed("Not connected".to_string()))?;
579
580 let mut url = base_url.join(&mapping.path)
581 .map_err(|e| ConnectorError::QueryExecutionFailed(
582 format!("Failed to build URL: {}", e)
583 ))?;
584
585 {
587 let mut query_pairs = url.query_pairs_mut();
588 for (key, value) in &mapping.query_params {
589 query_pairs.append_pair(key, value);
590 }
591 }
592
593 let start_time = Instant::now();
594 let cache_key = format!("{}:{}", endpoint_name, url.as_str());
595
596 let mut temp_connector = RestConnector {
599 client: self.client.clone(),
600 base_url: self.base_url.clone(),
601 auth_config: self.auth_config.clone(),
602 cache: self.cache.clone(),
603 cache_ttl: self.cache_ttl,
604 rate_limiter: None, connected: self.connected,
606 endpoint_mappings: self.endpoint_mappings.clone(),
607 };
608
609 let json_data = temp_connector.get_cached_or_fetch(&cache_key, &url, mapping.method.clone()).await?;
610 let data_array = temp_connector.extract_data_array(&json_data, mapping.response_path.as_deref())?;
611
612 let filtered_data = temp_connector.apply_predicates(data_array, &query.query.predicates);
614
615 let schema = temp_connector.infer_schema_from_json(&filtered_data, endpoint_name);
617
618 let mut rows = Vec::new();
620 for item in &filtered_data {
621 let row = temp_connector.json_to_row(item, &schema.columns);
622 rows.push(row);
623 }
624
625 if let Some(limit) = query.query.limit {
627 rows.truncate(limit as usize);
628 }
629
630 let execution_time = start_time.elapsed();
631
632 Ok(QueryResult {
633 columns: schema.columns,
634 rows,
635 affected_rows: Some(filtered_data.len() as u64),
636 execution_time,
637 })
638 }
639
640 async fn get_schema(&self, object_name: &str) -> NirvResult<Schema> {
641 if !self.connected {
642 return Err(ConnectorError::ConnectionFailed("Not connected".to_string()).into());
643 }
644
645 let mapping = self.endpoint_mappings.get(object_name)
647 .ok_or_else(|| ConnectorError::SchemaRetrievalFailed(
648 format!("No endpoint mapping found for '{}'", object_name)
649 ))?;
650
651 let base_url = self.base_url.as_ref()
652 .ok_or_else(|| ConnectorError::ConnectionFailed("Not connected".to_string()))?;
653
654 let mut url = base_url.join(&mapping.path)
655 .map_err(|e| ConnectorError::SchemaRetrievalFailed(
656 format!("Failed to build URL: {}", e)
657 ))?;
658
659 {
661 let mut query_pairs = url.query_pairs_mut();
662 for (key, value) in &mapping.query_params {
663 query_pairs.append_pair(key, value);
664 }
665 }
666
667 let cache_key = format!("schema:{}:{}", object_name, url.as_str());
668
669 let mut temp_connector = RestConnector {
671 client: self.client.clone(),
672 base_url: self.base_url.clone(),
673 auth_config: self.auth_config.clone(),
674 cache: self.cache.clone(),
675 cache_ttl: self.cache_ttl,
676 rate_limiter: None,
677 connected: self.connected,
678 endpoint_mappings: self.endpoint_mappings.clone(),
679 };
680
681 let json_data = temp_connector.get_cached_or_fetch(&cache_key, &url, mapping.method.clone()).await?;
682 let data_array = temp_connector.extract_data_array(&json_data, mapping.response_path.as_deref())?;
683
684 Ok(temp_connector.infer_schema_from_json(&data_array, object_name))
685 }
686
687 async fn disconnect(&mut self) -> NirvResult<()> {
688 self.client = None;
689 self.base_url = None;
690 self.connected = false;
691 self.cache.clear();
692 Ok(())
693 }
694
695 fn get_connector_type(&self) -> ConnectorType {
696 ConnectorType::Rest
697 }
698
699 fn supports_transactions(&self) -> bool {
700 false }
702
703 fn is_connected(&self) -> bool {
704 self.connected
705 }
706
707 fn get_capabilities(&self) -> ConnectorCapabilities {
708 ConnectorCapabilities {
709 supports_joins: false, supports_aggregations: true, supports_subqueries: false,
712 supports_transactions: false,
713 supports_schema_introspection: true,
714 max_concurrent_queries: Some(5), }
716 }
717}