use reqwest::{Client, header};
use serde::{de::DeserializeOwned, Serialize};
use std::time::Duration;
use crate::error::{Error, Result};
use crate::types::*;
/// Base URL used by `RedditInsightsClient::new` when the caller does not supply one.
const DEFAULT_BASE_URL: &str = "https://reddit-insights.com";
/// Request timeout, in seconds, applied to the HTTP client built in `with_base_url`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Asynchronous client for the Reddit Insights HTTP API.
///
/// Build one with `RedditInsightsClient::new` (default base URL) or
/// `RedditInsightsClient::with_base_url` (custom base URL).
pub struct RedditInsightsClient {
// API key supplied at construction. The visible code sends it through the
// client's default `Authorization` header and never reads this field again.
// NOTE(review): confirm whether anything else in the crate relies on it.
api_key: String,
// Base URL with trailing `/` characters removed; endpoint paths are appended to it.
base_url: String,
// Underlying `reqwest` client carrying the default headers and timeout.
client: Client,
}
impl RedditInsightsClient {
/// Creates a client that talks to the default base URL (`DEFAULT_BASE_URL`).
///
/// # Errors
///
/// Returns whatever error `with_base_url` produces for the given key.
pub fn new(api_key: impl Into<String>) -> Result<Self> {
    let key: String = api_key.into();
    Self::with_base_url(key, DEFAULT_BASE_URL)
}
pub fn with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Result<Self> {
let api_key = api_key.into();
let mut headers = header::HeaderMap::new();
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(|e| Error::Request(e.to_string()))?,
);
headers.insert(
header::USER_AGENT,
header::HeaderValue::from_static("reddit-insights-rust/1.0.0"),
);
let client = Client::builder()
.default_headers(headers)
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()?;
Ok(Self {
api_key,
base_url: base_url.into().trim_end_matches('/').to_string(),
client,
})
}
/// Sends one HTTP request and decodes the reply.
///
/// The URL is `base_url` immediately followed by `endpoint`. When `body` is
/// `Some`, it is serialized as the JSON request body. The full response text is
/// read and handed, along with the numeric status code, to `handle_response`.
///
/// # Errors
///
/// Propagates failures from sending the request or reading the response text,
/// plus any error returned by `handle_response`.
async fn request<T, B>(&self, method: reqwest::Method, endpoint: &str, body: Option<&B>) -> Result<T>
where
    T: DeserializeOwned,
    B: Serialize + ?Sized,
{
    let target = [self.base_url.as_str(), endpoint].concat();
    let builder = self.client.request(method, &target);
    let builder = match body {
        Some(payload) => builder.json(payload),
        None => builder,
    };

    let response = builder.send().await?;
    let code = response.status().as_u16();
    let raw = response.text().await?;
    self.handle_response(code, &raw)
}
/// Converts a status code and response body into a typed value or an error.
///
/// For a 2xx `status` the body is deserialized from JSON into `T`.
/// Otherwise an error message is extracted from the body, trying in order:
///
/// 1. the `error` field, when it is a string;
/// 2. the `message` field nested inside an `error` object;
/// 3. a top-level `message` string;
/// 4. the trimmed raw body, capped at a fixed number of characters;
/// 5. the literal `"Unknown error"` when the body is blank.
///
/// # Errors
///
/// * The error converted from `serde_json` when a 2xx body does not decode into `T`.
/// * `Error::Authentication` for 401, `Error::RateLimit` for 429,
///   `Error::Validation` for 400, and `Error::Api` for every other non-2xx status.
fn handle_response<T: DeserializeOwned>(&self, status: u16, body: &str) -> Result<T> {
    // Upper bound on how much of an unrecognised body is copied into the error
    // message, so a large HTML error page does not end up verbatim in the error.
    const MAX_RAW_ERROR_CHARS: usize = 512;

    if (200..300).contains(&status) {
        return serde_json::from_str(body).map_err(Error::from);
    }

    let parsed: Option<serde_json::Value> = serde_json::from_str(body).ok();
    let extracted = parsed.as_ref().and_then(|v| {
        let err = v.get("error");
        err.and_then(|e| e.as_str())
            .or_else(|| err.and_then(|e| e.get("message")).and_then(|m| m.as_str()))
            .or_else(|| v.get("message").and_then(|m| m.as_str()))
    });

    let error_msg = match extracted {
        Some(msg) => msg.to_string(),
        None => {
            // Keep whatever the server said rather than discarding it.
            let raw = body.trim();
            if raw.is_empty() {
                "Unknown error".to_string()
            } else {
                raw.chars().take(MAX_RAW_ERROR_CHARS).collect()
            }
        }
    };

    Err(match status {
        401 => Error::Authentication(error_msg),
        429 => Error::RateLimit(error_msg),
        400 => Error::Validation(error_msg),
        _ => Error::Api {
            status_code: status,
            message: error_msg,
        },
    })
}
/// Issues a `POST` to `/api/v1/search/semantic`.
///
/// The JSON body contains `query` and `limit`; `limit` defaults to 20 when `None`.
///
/// # Errors
///
/// Returns any error produced by the underlying request or response handling.
pub async fn semantic_search(&self, query: &str, limit: Option<i32>) -> Result<SemanticSearchResponse> {
    let max_results = limit.unwrap_or(20);
    let payload = serde_json::json!({
        "query": query,
        "limit": max_results
    });
    self.request(reqwest::Method::POST, "/api/v1/search/semantic", Some(&payload))
        .await
}
/// Issues a `POST` to `/api/v1/search/vector`.
///
/// The JSON body always contains `query` and `limit` (default 30 when `None`).
/// `start_date` and `end_date` are included only when provided and are passed
/// through as given strings.
///
/// # Errors
///
/// Returns any error produced by the underlying request or response handling.
pub async fn vector_search(
    &self,
    query: &str,
    limit: Option<i32>,
    start_date: Option<&str>,
    end_date: Option<&str>,
) -> Result<VectorSearchResponse> {
    let mut fields = serde_json::Map::new();
    fields.insert("query".to_string(), serde_json::Value::from(query));
    fields.insert(
        "limit".to_string(),
        serde_json::Value::from(limit.unwrap_or(30)),
    );
    if let Some(from) = start_date {
        fields.insert("start_date".to_string(), serde_json::Value::from(from));
    }
    if let Some(until) = end_date {
        fields.insert("end_date".to_string(), serde_json::Value::from(until));
    }

    let payload = serde_json::Value::Object(fields);
    self.request(reqwest::Method::POST, "/api/v1/search/vector", Some(&payload))
        .await
}
/// Issues a `POST` to `/api/v1/trends`.
///
/// The JSON body always contains `limit` (default 20 when `None`).
/// `start_date` and `end_date` are added only when provided.
///
/// # Errors
///
/// Returns any error produced by the underlying request or response handling.
pub async fn get_trends(
    &self,
    start_date: Option<&str>,
    end_date: Option<&str>,
    limit: Option<i32>,
) -> Result<TrendsResponse> {
    let mut payload = serde_json::json!({ "limit": limit.unwrap_or(20) });

    // Optional date bounds, in the order they are added to the body.
    let optional_dates = [("start_date", start_date), ("end_date", end_date)];
    for &(key, value) in optional_dates.iter() {
        if let Some(date) = value {
            payload[key] = serde_json::json!(date);
        }
    }

    self.request(reqwest::Method::POST, "/api/v1/trends", Some(&payload))
        .await
}
/// Issues a `GET` to `/api/v1/sonars` with no request body.
///
/// # Errors
///
/// Returns any error produced by the underlying request or response handling.
pub async fn list_sonars(&self) -> Result<SonarsResponse> {
    // `()` only pins the generic body type; with `None` nothing is serialized.
    let no_body: Option<&()> = None;
    self.request(reqwest::Method::GET, "/api/v1/sonars", no_body)
        .await
}
/// Issues a `POST` to `/api/v1/sonars`, sending `options` serialized as the JSON body.
///
/// # Errors
///
/// Returns any error produced by the underlying request or response handling.
pub async fn create_sonar(&self, options: &CreateSonarOptions) -> Result<CreateSonarResponse> {
    let payload = Some(options);
    self.request(reqwest::Method::POST, "/api/v1/sonars", payload)
        .await
}
}