1use reqwest::Client;
2use serde::de::DeserializeOwned;
3use std::error::Error;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::time::{Duration, Instant};
7use tokio::time::interval;
8
/// Default delay between polls, in milliseconds (initial `poll_interval_ms` of the builder).
pub const POLL_INTERVAL_MS: u64 = 500;
/// Default maximum number of idle pooled connections kept per host.
pub const POOL_MAX_IDLE_PER_HOST: usize = 1;
/// Default idle timeout for pooled connections, in seconds.
pub const POOL_IDLE_TIMEOUT_SECS: u64 = 90;
/// Default per-request timeout, in milliseconds.
pub const REQUEST_TIMEOUT_MS: u64 = 1_000;
/// Default TCP keepalive setting, in seconds.
pub const TCP_KEEPALIVE_SECS: u64 = 60;
14
15pub struct JsonPoller<T> {
16 client: Client,
17 url: String,
18 poll_interval: Duration,
19 _phantom: PhantomData<T>,
20}
21
/// Collects the settings needed to construct a `JsonPoller<T>`.
///
/// Every setter consumes the builder and returns it, so the value must be chained
/// into `build()` (or rebound) for a setting to take effect.
#[must_use = "builders do nothing unless `build()` is called"]
pub struct JsonPollerBuilder<T> {
    /// Address the resulting poller will request.
    url: String,
    /// Delay between polls, in milliseconds.
    poll_interval_ms: u64,
    /// Maximum idle pooled connections per host passed to the HTTP client builder.
    pool_max_idle_per_host: usize,
    /// Idle timeout for pooled connections, in seconds.
    pool_idle_timeout_secs: u64,
    /// Per-request timeout, in milliseconds.
    request_timeout_ms: u64,
    /// TCP keepalive setting, in seconds.
    tcp_keepalive_secs: u64,
    /// Carries the payload type only; `fn() -> T` keeps the builder `Send`/`Sync`
    /// regardless of `T`, since no `T` is ever stored here.
    _phantom: PhantomData<fn() -> T>,
}
31
32impl<T> JsonPollerBuilder<T> {
33 pub fn new(url: impl Into<String>) -> Self {
34 Self {
35 url: url.into(),
36 poll_interval_ms: POLL_INTERVAL_MS,
37 pool_max_idle_per_host: POOL_MAX_IDLE_PER_HOST,
38 pool_idle_timeout_secs: POOL_IDLE_TIMEOUT_SECS,
39 request_timeout_ms: REQUEST_TIMEOUT_MS,
40 tcp_keepalive_secs: TCP_KEEPALIVE_SECS,
41 _phantom: PhantomData,
42 }
43 }
44
45 pub fn poll_interval_ms(mut self, ms: u64) -> Self {
46 self.poll_interval_ms = ms;
47 self
48 }
49
50 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
51 self.pool_max_idle_per_host = max;
52 self
53 }
54
55 pub fn pool_idle_timeout_secs(mut self, secs: u64) -> Self {
56 self.pool_idle_timeout_secs = secs;
57 self
58 }
59
60 pub fn request_timeout_ms(mut self, ms: u64) -> Self {
61 self.request_timeout_ms = ms;
62 self
63 }
64
65 pub fn tcp_keepalive_secs(mut self, secs: u64) -> Self {
66 self.tcp_keepalive_secs = secs;
67 self
68 }
69
70 pub fn build(self) -> Result<JsonPoller<T>, reqwest::Error> {
71 let client = Client::builder()
72 .pool_max_idle_per_host(self.pool_max_idle_per_host)
73 .pool_idle_timeout(Duration::from_secs(self.pool_idle_timeout_secs))
74 .timeout(Duration::from_millis(self.request_timeout_ms))
75 .tcp_keepalive(Duration::from_secs(self.tcp_keepalive_secs))
76 .build()?;
77
78 Ok(JsonPoller {
79 client,
80 url: self.url,
81 poll_interval: Duration::from_millis(self.poll_interval_ms),
82 _phantom: PhantomData,
83 })
84 }
85}
86
87impl<T> JsonPoller<T>
88where
89 T: DeserializeOwned + Send,
90{
91 pub fn builder(url: impl Into<String>) -> JsonPollerBuilder<T> {
92 JsonPollerBuilder::new(url)
93 }
94
95 pub async fn start<F, Fut, E>(&self, mut on_data: F) -> Result<(), E>
96 where
97 F: FnMut(T, Duration) -> Fut + Send,
98 Fut: Future<Output = Result<(), E>> + Send,
99 E: std::fmt::Debug,
100 {
101 let mut interval_timer = interval(self.poll_interval);
102 interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
103
104 loop {
105 interval_timer.tick().await;
106 let request_start = Instant::now();
107 match self.fetch().await {
108 Ok(data) => {
109 let elapsed = request_start.elapsed();
110 on_data(data, elapsed).await?;
111 }
112 Err(e) => {
113 tracing::error!("Failed to fetch data: {:?}", e);
114 continue;
115 }
116 }
117 }
118 }
119
120 async fn fetch(&self) -> Result<T, Box<dyn Error + Send + Sync>> {
121 let response = self.client.get(&self.url).send().await?;
122
123 if !response.status().is_success() {
124 return Err(format!("HTTP {}", response.status()).into());
125 }
126
127 Ok(response.json::<T>().await?)
128 }
129
130 pub async fn fetch_once(&self) -> Result<T, Box<dyn Error + Send + Sync>> {
131 self.fetch().await
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Top-level document the tests decode from `https://httpbin.org/json`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct HttpBinJson {
        slideshow: Slideshow,
    }

    /// The `slideshow` object inside [`HttpBinJson`].
    #[derive(Debug, Deserialize, PartialEq)]
    struct Slideshow {
        author: String,
        date: String,
        title: String,
        slides: Vec<Slide>,
    }

    /// One entry of `slideshow.slides`; `items` falls back to an empty list when absent.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Slide {
        title: String,
        #[serde(rename = "type")]
        slide_type: String,
        #[serde(default)]
        items: Vec<String>,
    }

    /// Builds a poller for `url` using only the builder defaults.
    fn default_poller(url: &str) -> JsonPoller<HttpBinJson> {
        JsonPoller::<HttpBinJson>::builder(url).build().unwrap()
    }

    #[test]
    fn test_builder_defaults() {
        let built = default_poller("https://example.com");
        let expected_interval = Duration::from_millis(POLL_INTERVAL_MS);

        assert_eq!(built.poll_interval, expected_interval);
        assert_eq!(built.url, "https://example.com");
    }

    #[test]
    fn test_builder_custom_config() {
        let builder = JsonPoller::<HttpBinJson>::builder("https://example.com")
            .poll_interval_ms(1000)
            .request_timeout_ms(2000);
        let built = builder.build().unwrap();

        assert_eq!(built.poll_interval, Duration::from_millis(1000));
    }

    // The remaining tests talk to httpbin.org and therefore need network access.

    #[tokio::test]
    async fn test_http_error() {
        let poller = default_poller("https://httpbin.org/status/404");

        let outcome = poller.fetch_once().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_invalid_json() {
        let poller = default_poller("https://httpbin.org/html");

        let outcome = poller.fetch_once().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_fetch_once() {
        let poller = default_poller("https://httpbin.org/json");
        let payload = poller.fetch_once().await.unwrap();
        let show = &payload.slideshow;

        assert_eq!(show.author, "Yours Truly");
        assert_eq!(show.title, "Sample Slide Show");
    }
}