1use std::time::Duration;
2
3use serde::de::DeserializeOwned;
4use serde_json::{json, Value};
5
6use crate::error::{Error, Result};
7use crate::sessions::Sessions;
8use crate::types::{
9 AgenticSearchResult, CrawlResult, Document, MapResult, SearchResult, WireResult,
10};
11
/// Version of this SDK; sent as `anakin-rust/{VERSION}` in the `User-Agent` header.
pub const VERSION: &str = "0.1.0";

/// API root used when no base URL is supplied to the builder.
const DEFAULT_BASE_URL: &str = "https://api.anakin.io/v1";
16
17#[derive(Clone)]
33pub struct Client {
34 http: reqwest::Client,
35 api_key: String,
36 base_url: String,
37 max_retries: u32,
38 poll_interval: Duration,
39 poll_max_interval: Duration,
40 poll_timeout: Duration,
41}
42
43impl std::fmt::Debug for Client {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("Client")
46 .field("base_url", &self.base_url)
47 .field("max_retries", &self.max_retries)
48 .field("poll_interval", &self.poll_interval)
49 .field("poll_max_interval", &self.poll_max_interval)
50 .field("poll_timeout", &self.poll_timeout)
51 .finish_non_exhaustive()
52 }
53}
54
55impl Client {
56 pub fn builder() -> ClientBuilder {
58 ClientBuilder::default()
59 }
60
61 pub fn sessions(&self) -> Sessions<'_> {
63 Sessions::new(self)
64 }
65
66 pub async fn scrape(&self, url: &str) -> Result<Document> {
70 self.scrape_with(url, None).await
71 }
72
73 pub async fn scrape_with(&self, url: &str, opts: Option<Value>) -> Result<Document> {
75 let body = build_body(json!({ "url": url }), opts);
76 let submit: Value = self.send_json(reqwest::Method::POST, "/url-scraper", Some(body)).await?;
77 let job_id = require_string(&submit, "job_id")?;
78 let poll = self.poll_job(&format!("/url-scraper/{job_id}")).await?;
79 decode_field(&poll, "result")
80 }
81
82 pub async fn map(&self, url: &str) -> Result<MapResult> {
84 self.map_with(url, None).await
85 }
86
87 pub async fn map_with(&self, url: &str, opts: Option<Value>) -> Result<MapResult> {
88 let body = build_body(json!({ "url": url }), opts);
89 let submit: Value = self.send_json(reqwest::Method::POST, "/map", Some(body)).await?;
90 let job_id = require_string(&submit, "job_id")?;
91 let poll = self.poll_job(&format!("/map/{job_id}")).await?;
92 decode_field(&poll, "result")
93 }
94
95 pub async fn crawl(&self, url: &str) -> Result<CrawlResult> {
97 self.crawl_with(url, None).await
98 }
99
100 pub async fn crawl_with(&self, url: &str, opts: Option<Value>) -> Result<CrawlResult> {
101 let body = build_body(json!({ "url": url }), opts);
102 let submit: Value = self.send_json(reqwest::Method::POST, "/crawl", Some(body)).await?;
103 let job_id = require_string(&submit, "job_id")?;
104 let poll = self.poll_job(&format!("/crawl/{job_id}")).await?;
105 decode_field(&poll, "result")
106 }
107
108 pub async fn search(&self, query: &str) -> Result<SearchResult> {
110 self.search_with(query, None).await
111 }
112
113 pub async fn search_with(&self, query: &str, opts: Option<Value>) -> Result<SearchResult> {
114 let body = build_body(json!({ "prompt": query }), opts);
115 let v: Value = self.send_json(reqwest::Method::POST, "/search", Some(body)).await?;
116 serde_json::from_value(v).map_err(|e| Error::Other(format!("decode response: {e}")))
117 }
118
119 pub async fn agentic_search(&self, prompt: &str) -> Result<AgenticSearchResult> {
121 self.agentic_search_with(prompt, None).await
122 }
123
124 pub async fn agentic_search_with(&self, prompt: &str, opts: Option<Value>) -> Result<AgenticSearchResult> {
125 let body = build_body(json!({ "prompt": prompt }), opts);
126 let submit: Value = self.send_json(reqwest::Method::POST, "/agentic-search", Some(body)).await?;
127 let job_id = require_string(&submit, "job_id")?;
128 let poll = self.poll_job(&format!("/agentic-search/{job_id}")).await?;
129 decode_field(&poll, "result")
130 }
131
132 pub async fn wire(&self, action_id: &str, params: Option<Value>) -> Result<WireResult> {
134 let mut body = json!({ "action_id": action_id });
135 if let Some(p) = params {
136 body["params"] = p;
137 }
138 let submit: Value = self.send_json(reqwest::Method::POST, "/holocron/task", Some(body)).await?;
139 let job_id = require_string(&submit, "job_id")?;
140 let poll = self.poll_job(&format!("/holocron/task/{job_id}")).await?;
141 decode_field(&poll, "result")
142 }
143
144 pub(crate) async fn send_json(
147 &self,
148 method: reqwest::Method,
149 path: &str,
150 body: Option<Value>,
151 ) -> Result<Value> {
152 let url = format!("{}{}", self.base_url, path);
153 let mut last_resp: Option<reqwest::Response> = None;
154
155 for attempt in 0..=self.max_retries {
156 if attempt > 0 {
157 let delay = backoff(attempt, last_resp.as_ref());
158 tokio::time::sleep(delay).await;
159 }
160
161 let mut req = self
162 .http
163 .request(method.clone(), &url)
164 .header("X-API-Key", &self.api_key)
165 .header("Accept", "application/json")
166 .header("User-Agent", format!("anakin-rust/{VERSION}"));
167 if let Some(ref b) = body {
168 req = req.json(b);
169 }
170
171 let resp = match req.send().await {
172 Ok(r) => r,
173 Err(e) => {
174 if attempt == self.max_retries {
175 return Err(Error::Network {
176 message: format!(
177 "http request after {} retries: {}",
178 self.max_retries, e
179 ),
180 source: Some(Box::new(e)),
181 });
182 }
183 continue;
184 }
185 };
186
187 let status = resp.status().as_u16();
188 if should_retry(status) && attempt < self.max_retries {
189 last_resp = Some(resp);
190 continue;
191 }
192 return self.handle_response(resp).await;
193 }
194
195 Err(Error::Other("retry loop exited unexpectedly".into()))
198 }
199
200 async fn handle_response(&self, resp: reqwest::Response) -> Result<Value> {
201 let status = resp.status();
202 let retry_after_header = resp
203 .headers()
204 .get("Retry-After")
205 .and_then(|v| v.to_str().ok())
206 .map(|s| s.to_string());
207
208 if status.is_success() {
209 let bytes = resp.bytes().await.map_err(|e| Error::Network {
210 message: format!("read response: {e}"),
211 source: Some(Box::new(e)),
212 })?;
213 if bytes.is_empty() {
214 return Ok(Value::Object(serde_json::Map::new()));
215 }
216 return serde_json::from_slice(&bytes)
217 .map_err(|e| Error::Other(format!("decode response: {e}")));
218 }
219
220 let body_bytes = resp.bytes().await.unwrap_or_default();
221 Err(map_error(status.as_u16(), retry_after_header.as_deref(), &body_bytes))
222 }
223
224 async fn poll_job(&self, path: &str) -> Result<Value> {
225 let deadline = std::time::Instant::now() + self.poll_timeout;
226 let mut delay = self.poll_interval;
227
228 loop {
229 let v = self.send_json(reqwest::Method::GET, path, None).await?;
230 let status = v.get("status").and_then(|s| s.as_str()).unwrap_or("");
231 let error = v.get("error").and_then(|s| s.as_str()).unwrap_or("");
232 let job_id = v.get("job_id").and_then(|s| s.as_str()).map(|s| s.to_string());
233
234 if status == "completed" || status == "succeeded" {
235 return Ok(v);
236 }
237 if status == "failed" {
238 return Err(Error::JobFailed {
239 job_id,
240 reason: error.to_string(),
241 });
242 }
243
244 if std::time::Instant::now() > deadline {
245 return Err(Error::JobTimeout {
246 job_id,
247 elapsed: self.poll_timeout,
248 });
249 }
250
251 tokio::time::sleep(delay).await;
252 let next_ms = (delay.as_millis() as f64 * 1.5) as u64;
253 let capped = std::cmp::min(next_ms, self.poll_max_interval.as_millis() as u64);
254 delay = Duration::from_millis(capped);
255 }
256 }
257}
258
259fn build_body(base: Value, extra: Option<Value>) -> Value {
262 let mut obj = base.as_object().cloned().unwrap_or_default();
263 if let Some(Value::Object(map)) = extra {
264 for (k, v) in map {
265 obj.insert(k, v);
266 }
267 }
268 Value::Object(obj)
269}
270
/// Returns `true` for HTTP statuses that are worth retrying: 429 (rate limited) and any
/// 5xx server error.
fn should_retry(status: u16) -> bool {
    matches!(status, 429 | 500..=599)
}
274
275fn backoff(attempt: u32, prev: Option<&reqwest::Response>) -> Duration {
276 if let Some(r) = prev {
277 if let Some(ra) = r
278 .headers()
279 .get("Retry-After")
280 .and_then(|v| v.to_str().ok())
281 .and_then(|s| s.parse::<u64>().ok())
282 {
283 if ra > 0 {
284 return Duration::from_secs(ra);
285 }
286 }
287 }
288 let ms = (2_u64.saturating_pow(attempt.saturating_sub(1))) * 500;
289 Duration::from_millis(std::cmp::min(ms, 30_000))
290}
291
/// Parses a `Retry-After` header value expressed in whole seconds.
///
/// Surrounding whitespace is ignored. A missing header, or a value that is not a
/// non-negative integer (such as the HTTP-date form), yields `Duration::ZERO`.
fn parse_retry_after(header: Option<&str>) -> Duration {
    header
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .map_or(Duration::ZERO, Duration::from_secs)
}
301
302fn map_error(status: u16, retry_after_header: Option<&str>, body: &[u8]) -> Error {
303 let parsed: Value = serde_json::from_slice(body).unwrap_or(Value::Null);
304 let message = parsed
305 .get("error")
306 .and_then(|v| v.as_str())
307 .unwrap_or("")
308 .to_string();
309 let code = parsed
310 .get("code")
311 .and_then(|v| v.as_str())
312 .map(|s| s.to_string());
313 let balance = parsed.get("balance").and_then(|v| v.as_i64()).unwrap_or(0);
314 let required = parsed.get("required").and_then(|v| v.as_i64()).unwrap_or(0);
315 let message = if message.is_empty() {
316 format!("HTTP {status}")
317 } else {
318 message
319 };
320
321 match status {
322 400 => Error::InvalidRequest { message, status, code },
323 401 => Error::Authentication { message, status, code },
324 402 => Error::InsufficientCredits { message, status, code, balance, required },
325 429 => Error::RateLimit {
326 message,
327 status,
328 code,
329 retry_after: parse_retry_after(retry_after_header),
330 },
331 s if s >= 500 => Error::Server { message, status, code },
332 _ => Error::Other(format!("HTTP {status}: {message}")),
333 }
334}
335
336fn require_string(v: &Value, field: &str) -> Result<String> {
337 v.get(field)
338 .and_then(|x| x.as_str())
339 .filter(|s| !s.is_empty())
340 .map(|s| s.to_string())
341 .ok_or_else(|| Error::Other(format!("API response missing required field: {field}")))
342}
343
344fn decode_field<T: DeserializeOwned + Default>(parent: &Value, field: &str) -> Result<T> {
345 match parent.get(field) {
346 Some(v) if !v.is_null() => serde_json::from_value(v.clone())
347 .map_err(|e| Error::Other(format!("decode response: {e}"))),
348 _ => Ok(T::default()),
349 }
350}
351
352pub struct ClientBuilder {
357 api_key: Option<String>,
358 base_url: Option<String>,
359 timeout: Duration,
360 max_retries: u32,
361 poll_interval: Duration,
362 poll_max_interval: Duration,
363 poll_timeout: Duration,
364 http: Option<reqwest::Client>,
365}
366
367impl Default for ClientBuilder {
368 fn default() -> Self {
369 Self {
370 api_key: None,
371 base_url: None,
372 timeout: Duration::from_secs(60),
373 max_retries: 4,
374 poll_interval: Duration::from_secs(1),
375 poll_max_interval: Duration::from_secs(10),
376 poll_timeout: Duration::from_secs(300),
377 http: None,
378 }
379 }
380}
381
382impl ClientBuilder {
383 pub fn api_key(mut self, key: impl Into<String>) -> Self {
384 self.api_key = Some(key.into());
385 self
386 }
387
388 pub fn base_url(mut self, url: impl Into<String>) -> Self {
389 self.base_url = Some(url.into());
390 self
391 }
392
393 pub fn timeout(mut self, t: Duration) -> Self {
394 self.timeout = t;
395 self
396 }
397
398 pub fn max_retries(mut self, n: u32) -> Self {
399 self.max_retries = n;
400 self
401 }
402
403 pub fn poll_interval(mut self, d: Duration) -> Self {
404 self.poll_interval = d;
405 self
406 }
407
408 pub fn poll_max_interval(mut self, d: Duration) -> Self {
409 self.poll_max_interval = d;
410 self
411 }
412
413 pub fn poll_timeout(mut self, d: Duration) -> Self {
414 self.poll_timeout = d;
415 self
416 }
417
418 pub fn http_client(mut self, c: reqwest::Client) -> Self {
420 self.http = Some(c);
421 self
422 }
423
424 pub fn build(self) -> Result<Client> {
425 let api_key = self
426 .api_key
427 .or_else(|| std::env::var("ANAKIN_API_KEY").ok())
428 .filter(|s| !s.is_empty())
429 .ok_or_else(|| {
430 Error::Other(
431 "no API key — call .api_key(...) on the builder or set ANAKIN_API_KEY".into(),
432 )
433 })?;
434
435 let http = match self.http {
436 Some(c) => c,
437 None => reqwest::Client::builder()
438 .timeout(self.timeout)
439 .user_agent(format!("anakin-rust/{VERSION}"))
440 .build()
441 .map_err(|e| Error::Other(format!("build http client: {e}")))?,
442 };
443
444 Ok(Client {
445 http,
446 api_key,
447 base_url: self
448 .base_url
449 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
450 .trim_end_matches('/')
451 .to_string(),
452 max_retries: self.max_retries,
453 poll_interval: self.poll_interval,
454 poll_max_interval: self.poll_max_interval,
455 poll_timeout: self.poll_timeout,
456 })
457 }
458}