1use reqwest::Client;
2use serde::de::DeserializeOwned;
3use std::marker::PhantomData;
4use std::time::{Duration, Instant};
5use tokio::time::interval;
6
/// Default delay between poll ticks, in milliseconds (used by `JsonPollerBuilder::new`).
pub const POLL_INTERVAL_MS: u64 = 500;
/// Default value handed to reqwest's `pool_max_idle_per_host`.
pub const POOL_MAX_IDLE_PER_HOST: usize = 1;
/// Default value, in seconds, handed to reqwest's `pool_idle_timeout`.
pub const POOL_IDLE_TIMEOUT_SECS: u64 = 90;
/// Default per-request timeout, in milliseconds, handed to reqwest's `timeout`.
pub const REQUEST_TIMEOUT_MS: u64 = 1_000;
/// Default value, in seconds, handed to reqwest's `tcp_keepalive`.
pub const TCP_KEEPALIVE_SECS: u64 = 60;
12
13pub struct JsonPoller<T> {
14 client: Client,
15 url: String,
16 poll_interval: Duration,
17 _phantom: PhantomData<T>,
18}
19
/// Accumulates settings for a [`JsonPoller`]; finish with [`JsonPollerBuilder::build`].
///
/// Every numeric field starts at the matching module-level default constant.
pub struct JsonPollerBuilder<T> {
    /// Endpoint the poller will request.
    url: String,
    /// Polling period, in milliseconds.
    poll_interval_ms: u64,
    /// Passed to reqwest's `pool_max_idle_per_host`.
    pool_max_idle_per_host: usize,
    /// Passed (as seconds) to reqwest's `pool_idle_timeout`.
    pool_idle_timeout_secs: u64,
    /// Passed (as milliseconds) to reqwest's per-request `timeout`.
    request_timeout_ms: u64,
    /// Passed (as seconds) to reqwest's `tcp_keepalive`.
    tcp_keepalive_secs: u64,
    /// Binds the builder to the payload type of the poller it produces.
    _phantom: PhantomData<T>,
}
29
30impl<T> JsonPollerBuilder<T> {
31 pub fn new(url: impl Into<String>) -> Self {
32 Self {
33 url: url.into(),
34 poll_interval_ms: POLL_INTERVAL_MS,
35 pool_max_idle_per_host: POOL_MAX_IDLE_PER_HOST,
36 pool_idle_timeout_secs: POOL_IDLE_TIMEOUT_SECS,
37 request_timeout_ms: REQUEST_TIMEOUT_MS,
38 tcp_keepalive_secs: TCP_KEEPALIVE_SECS,
39 _phantom: PhantomData,
40 }
41 }
42
43 pub fn poll_interval_ms(mut self, ms: u64) -> Self {
44 self.poll_interval_ms = ms;
45 self
46 }
47
48 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
49 self.pool_max_idle_per_host = max;
50 self
51 }
52
53 pub fn pool_idle_timeout_secs(mut self, secs: u64) -> Self {
54 self.pool_idle_timeout_secs = secs;
55 self
56 }
57
58 pub fn request_timeout_ms(mut self, ms: u64) -> Self {
59 self.request_timeout_ms = ms;
60 self
61 }
62
63 pub fn tcp_keepalive_secs(mut self, secs: u64) -> Self {
64 self.tcp_keepalive_secs = secs;
65 self
66 }
67
68 pub fn build(self) -> Result<JsonPoller<T>, reqwest::Error> {
69 let client = Client::builder()
70 .pool_max_idle_per_host(self.pool_max_idle_per_host)
71 .pool_idle_timeout(Duration::from_secs(self.pool_idle_timeout_secs))
72 .timeout(Duration::from_millis(self.request_timeout_ms))
73 .tcp_keepalive(Duration::from_secs(self.tcp_keepalive_secs))
74 .build()?;
75
76 Ok(JsonPoller {
77 client,
78 url: self.url,
79 poll_interval: Duration::from_millis(self.poll_interval_ms),
80 _phantom: PhantomData,
81 })
82 }
83}
84
85impl<T> JsonPoller<T>
113where
114 T: DeserializeOwned + Send,
115{
116 pub fn builder(url: impl Into<String>) -> JsonPollerBuilder<T> {
117 JsonPollerBuilder::new(url)
118 }
119
120 pub async fn start<F, Fut>(&self, mut on_data: F)
121 where
122 F: FnMut(T, Duration) -> Fut + Send,
123 Fut: std::future::Future<Output = ()> + Send,
124 {
125 let mut interval_timer = interval(self.poll_interval);
126 interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
127
128 loop {
129 interval_timer.tick().await;
130 let request_start = Instant::now();
131 match self.fetch().await {
132 Ok(data) => {
133 let elapsed = request_start.elapsed();
134 on_data(data, elapsed).await;
135 }
136 Err(e) => {
137 tracing::error!("Failed to fetch data: {:?}", e);
138 }
139 }
140 }
141 }
142
143 async fn fetch(&self) -> Result<T, Box<dyn std::error::Error>> {
144 let response = self.client.get(&self.url).send().await.map_err(|e| {
145 tracing::error!("Request failed: {:?}", e);
146 e
147 })?;
148
149 let status = response.status();
150 if !status.is_success() {
151 tracing::error!("HTTP error: {}", status);
152 return Err(format!("HTTP {}", status).into());
153 }
154
155 let data = response.json::<T>().await.map_err(|e| {
156 tracing::error!("JSON parse failed: {:?}", e);
157 e
158 })?;
159
160 Ok(data)
161 }
162
163 pub async fn fetch_once(&self) -> Result<T, Box<dyn std::error::Error>> {
164 self.fetch().await
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Shape of the sample document served at `https://httpbin.org/json`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct HttpBinJson {
        slideshow: Slideshow,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Slideshow {
        author: String,
        date: String,
        title: String,
        slides: Vec<Slide>,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Slide {
        title: String,
        #[serde(rename = "type")]
        slide_type: String,
        #[serde(default)]
        items: Vec<String>,
    }

    /// Builds a default-configured poller for `url`, panicking if the client fails to build.
    fn poller_for(url: &str) -> JsonPoller<HttpBinJson> {
        JsonPoller::<HttpBinJson>::builder(url).build().unwrap()
    }

    #[test]
    fn test_builder_defaults() {
        let poller = poller_for("https://example.com");

        let expected_interval = Duration::from_millis(POLL_INTERVAL_MS);
        assert_eq!(poller.poll_interval, expected_interval);
        assert_eq!(poller.url, "https://example.com");
    }

    #[test]
    fn test_builder_custom_config() {
        let builder = JsonPoller::<HttpBinJson>::builder("https://example.com")
            .poll_interval_ms(1000)
            .request_timeout_ms(2000);
        let poller = builder.build().unwrap();

        assert_eq!(poller.poll_interval, Duration::from_millis(1000));
    }

    // The remaining tests talk to httpbin.org and therefore need network access.

    #[tokio::test]
    async fn test_http_error() {
        let outcome = poller_for("https://httpbin.org/status/404").fetch_once().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_invalid_json() {
        let outcome = poller_for("https://httpbin.org/html").fetch_once().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_fetch_once() {
        let payload = poller_for("https://httpbin.org/json")
            .fetch_once()
            .await
            .unwrap();

        let show = &payload.slideshow;
        assert_eq!(show.author, "Yours Truly");
        assert_eq!(show.title, "Sample Slide Show");
    }
}