1use std::error::Error;
2use reqwest::Client;
3use serde::de::DeserializeOwned;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::time::{Duration, Instant};
7use tokio::time::interval;
8
/// Default polling period for [`JsonPoller::start`], in milliseconds.
pub const POLL_INTERVAL_MS: u64 = 500;
/// Default cap on idle connections the HTTP client pool keeps open per host.
pub const POOL_MAX_IDLE_PER_HOST: usize = 1;
/// Default lifetime, in seconds, of an idle pooled connection before it is closed.
pub const POOL_IDLE_TIMEOUT_SECS: u64 = 90;
/// Default total per-request timeout (connect through body read), in milliseconds.
pub const REQUEST_TIMEOUT_MS: u64 = 1_000;
/// Default TCP keepalive duration applied to the client's sockets, in seconds.
pub const TCP_KEEPALIVE_SECS: u64 = 60;
14
15pub struct JsonPoller<T> {
16 client: Client,
17 url: String,
18 poll_interval: Duration,
19 _phantom: PhantomData<T>,
20}
21
/// Builder for [`JsonPoller`], pre-populated with the crate-level default constants.
///
/// Every setter consumes the builder and returns it, so settings must be chained (or the
/// result rebound); `#[must_use]` turns a discarded intermediate builder into a warning.
#[must_use = "a builder does nothing until `build()` is called; setters return the updated builder"]
pub struct JsonPollerBuilder<T> {
    /// Endpoint the built poller will request.
    url: String,
    /// Polling period, in milliseconds.
    poll_interval_ms: u64,
    /// Maximum idle connections kept per host by the HTTP client pool.
    pool_max_idle_per_host: usize,
    /// Idle-connection timeout of the pool, in seconds.
    pool_idle_timeout_secs: u64,
    /// Total per-request timeout, in milliseconds.
    request_timeout_ms: u64,
    /// TCP keepalive duration, in seconds.
    tcp_keepalive_secs: u64,
    // `fn() -> T`: the builder never holds a `T`, so it should be `Send + Sync` for every `T`.
    _phantom: PhantomData<fn() -> T>,
}
31
32impl<T> JsonPollerBuilder<T> {
33 pub fn new(url: impl Into<String>) -> Self {
34 Self {
35 url: url.into(),
36 poll_interval_ms: POLL_INTERVAL_MS,
37 pool_max_idle_per_host: POOL_MAX_IDLE_PER_HOST,
38 pool_idle_timeout_secs: POOL_IDLE_TIMEOUT_SECS,
39 request_timeout_ms: REQUEST_TIMEOUT_MS,
40 tcp_keepalive_secs: TCP_KEEPALIVE_SECS,
41 _phantom: PhantomData,
42 }
43 }
44
45 pub fn poll_interval_ms(mut self, ms: u64) -> Self {
46 self.poll_interval_ms = ms;
47 self
48 }
49
50 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
51 self.pool_max_idle_per_host = max;
52 self
53 }
54
55 pub fn pool_idle_timeout_secs(mut self, secs: u64) -> Self {
56 self.pool_idle_timeout_secs = secs;
57 self
58 }
59
60 pub fn request_timeout_ms(mut self, ms: u64) -> Self {
61 self.request_timeout_ms = ms;
62 self
63 }
64
65 pub fn tcp_keepalive_secs(mut self, secs: u64) -> Self {
66 self.tcp_keepalive_secs = secs;
67 self
68 }
69
70 pub fn build(self) -> Result<JsonPoller<T>, reqwest::Error> {
71 let client = Client::builder()
72 .pool_max_idle_per_host(self.pool_max_idle_per_host)
73 .pool_idle_timeout(Duration::from_secs(self.pool_idle_timeout_secs))
74 .timeout(Duration::from_millis(self.request_timeout_ms))
75 .tcp_keepalive(Duration::from_secs(self.tcp_keepalive_secs))
76 .build()?;
77
78 Ok(JsonPoller {
79 client,
80 url: self.url,
81 poll_interval: Duration::from_millis(self.poll_interval_ms),
82 _phantom: PhantomData,
83 })
84 }
85}
86
87impl<T> JsonPoller<T>
88where
89 T: DeserializeOwned + Send,
90{
91 pub fn builder(url: impl Into<String>) -> JsonPollerBuilder<T> {
92 JsonPollerBuilder::new(url)
93 }
94
95 pub async fn start<F, Fut>(&self, mut on_data: F)
96 where
97 F: FnMut(T, Duration) -> Fut + Send,
98 Fut: Future<Output = ()> + Send,
99 {
100 let mut interval_timer = interval(self.poll_interval);
101 interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
102
103 loop {
104 interval_timer.tick().await;
105 let request_start = Instant::now();
106 match self.fetch().await {
107 Ok(data) => {
108 let elapsed = request_start.elapsed();
109 on_data(data, elapsed).await;
110 }
111 Err(e) => {
112 tracing::error!("Failed to fetch data: {:?}", e);
113 }
114 }
115 }
116 }
117
118 async fn fetch(&self) -> Result<T, Box<dyn Error + Send + Sync>> {
119 let response = self.client.get(&self.url).send().await.map_err(|e| {
120 tracing::error!("Request failed: {:?}", e);
121 Box::new(e) as Box<dyn Error + Send + Sync>
122 })?;
123
124 let status = response.status();
125 if !status.is_success() {
126 tracing::error!("HTTP error: {}", status);
127 return Err(format!("HTTP {}", status).into());
128 }
129
130 let data = response.json::<T>().await.map_err(|e| {
131 tracing::error!("JSON parse failed: {:?}", e);
132 Box::new(e) as Box<dyn Error + Send + Sync>
133 })?;
134
135 Ok(data)
136 }
137
138 pub async fn fetch_once(&self) -> Result<T, Box<dyn std::error::Error + Send + Sync>> {
139 self.fetch().await
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Shape of the document served by `https://httpbin.org/json`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct HttpBinJson {
        slideshow: Slideshow,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Slideshow {
        author: String,
        date: String,
        title: String,
        slides: Vec<Slide>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Slide {
        title: String,
        #[serde(rename = "type")]
        slide_type: String,
        #[serde(default)]
        items: Vec<String>,
    }

    /// Request timeout for tests that talk to httpbin.org. The 1 s library default is too tight
    /// for a public service: a slow response would surface as a timeout error, which either
    /// fails the happy-path test or makes an error test pass for the wrong reason.
    const NETWORK_TEST_TIMEOUT_MS: u64 = 10_000;

    /// Builds a poller for `url` with a timeout suited to public-network tests.
    fn network_poller(url: &str) -> JsonPoller<HttpBinJson> {
        JsonPoller::<HttpBinJson>::builder(url)
            .request_timeout_ms(NETWORK_TEST_TIMEOUT_MS)
            .build()
            .expect("client configuration must build")
    }

    #[test]
    fn test_builder_defaults() {
        let poller = JsonPoller::<HttpBinJson>::builder("https://example.com")
            .build()
            .unwrap();

        assert_eq!(
            poller.poll_interval,
            Duration::from_millis(POLL_INTERVAL_MS)
        );
        assert_eq!(poller.url, "https://example.com");
    }

    #[test]
    fn test_builder_custom_config() {
        let poller = JsonPoller::<HttpBinJson>::builder("https://example.com")
            .poll_interval_ms(1000)
            .request_timeout_ms(2000)
            .build()
            .unwrap();

        assert_eq!(poller.poll_interval, Duration::from_millis(1000));
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin.org"]
    async fn test_http_error() {
        let poller = network_poller("https://httpbin.org/status/404");

        let err = poller
            .fetch_once()
            .await
            .expect_err("a 404 response must be reported as an error");
        // Pin the failure to the status check, not e.g. a timeout or DNS error.
        assert!(
            err.to_string().starts_with("HTTP 404"),
            "unexpected error: {}",
            err
        );
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin.org"]
    async fn test_invalid_json() {
        let poller = network_poller("https://httpbin.org/html");

        let err = poller
            .fetch_once()
            .await
            .expect_err("an HTML body must not decode as HttpBinJson");
        let reqwest_err = err
            .downcast_ref::<reqwest::Error>()
            .expect("decode failures are surfaced as reqwest errors");
        assert!(reqwest_err.is_decode(), "unexpected error: {:?}", reqwest_err);
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin.org"]
    async fn test_fetch_once() {
        let json_poller = network_poller("https://httpbin.org/json");
        let data = json_poller.fetch_once().await.unwrap();

        assert_eq!(data.slideshow.author, "Yours Truly");
        assert_eq!(data.slideshow.title, "Sample Slide Show");
    }
}