1use core::{cmp::max, future::Future, ops::ControlFlow, pin::Pin};
2use std::{
3 io::{Cursor, Error as IoError, SeekFrom},
4 net::{IpAddr, Ipv4Addr, Ipv6Addr},
5};
6
7use deadpool::unmanaged::{Pool, PoolError};
8use futures_util::{AsyncRead, AsyncReadExt as _, AsyncSeek, AsyncSeekExt as _};
9
10use crate::{
11 content::{querier::FillError as ContentFillError, Querier as ContentQuerier},
12 header::{
13 parser::ParseError as HeaderParseError, Parser as HeaderParser, Schema as HeaderSchema,
14 HEADER_LEN,
15 },
16 index::{
17 querier::BuildError as IndexBuildError, V4Querier as IndexV4Querier,
18 V6Querier as IndexV6Querier, INDEX_LEN,
19 },
20 record_field::{RecordField, RecordFieldContents},
21 records::{
22 querier::v4_querier::NewError as RecordsV4QuerierNewError,
23 querier::v6_querier::NewError as RecordsV6QuerierNewError,
24 querier::Error as RecordsQueryError, V4Querier as RecordsV4Querier,
25 V6Querier as RecordsV6Querier,
26 },
27};
28
29pub struct Querier<S> {
31 pub header: HeaderSchema,
32 pub index_v4: IndexV4Querier,
33 pub index_v6: Option<IndexV6Querier>,
34 pub records_v4_pool: Pool<RecordsV4Querier<S>>,
35 pub records_v6_pool: Option<Pool<RecordsV6Querier<S>>>,
36 pub content_pool: Pool<ContentQuerier<S>>,
37}
38
39impl<S> core::fmt::Debug for Querier<S>
40where
41 Pool<RecordsV4Querier<S>>: core::fmt::Debug,
42 Option<Pool<RecordsV6Querier<S>>>: core::fmt::Debug,
43 Pool<ContentQuerier<S>>: core::fmt::Debug,
44{
45 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46 f.debug_struct("Querier")
47 .field("header", &self.header)
48 .field("index_v4", &self.index_v4)
49 .field("index_v6", &self.index_v6)
50 .field("records_v4_pool", &self.records_v4_pool)
51 .field("records_v6_pool", &self.records_v6_pool)
52 .field("content_pool", &self.content_pool)
53 .finish()
54 }
55}
56
57impl<S> Querier<S>
61where
62 S: AsyncSeek + AsyncRead + Unpin,
63{
64 pub async fn new<F>(mut stream_repeater: F, pool_max_size: usize) -> Result<Self, NewError>
65 where
66 F: FnMut() -> Pin<Box<dyn Future<Output = Result<S, IoError>> + Send + 'static>>,
67 {
68 let pool_max_size = max(1, pool_max_size);
69
70 let mut buf = vec![0; 1024 * 8];
71
72 let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
74
75 let header = {
77 let mut parser = HeaderParser::new();
78 let mut n_read = 0;
79 let mut n_parsed = 0;
80 loop {
81 let n = stream
82 .read(&mut buf[n_read..n_read + HEADER_LEN as usize])
83 .await
84 .map_err(NewError::ReadFailed)?;
85
86 if n == 0 {
87 return Err(NewError::ReadOtherError("header parsing is not completed"));
88 }
89
90 n_read += n;
91
92 match parser
93 .parse(&mut Cursor::new(&buf[n_parsed..n_read]))
94 .map_err(NewError::HeaderParseFailed)?
95 {
96 ControlFlow::Continue(n) => {
97 n_parsed += n;
98 continue;
99 }
100 ControlFlow::Break((_n, header)) => {
101 break header;
102 }
103 }
104 }
105 };
106
107 stream
109 .seek(SeekFrom::Start(header.total_size as u64))
110 .await
111 .map_err(NewError::SeekFailed)?;
112 let n = stream
113 .read(&mut buf[..1])
114 .await
115 .map_err(NewError::ReadFailed)?;
116 if n != 0 {
117 return Err(NewError::TotalSizeMissing);
118 }
119
120 let index_v4 = {
122 let mut builder = IndexV4Querier::builder();
123 let mut n_max_appended = INDEX_LEN as usize;
124 stream
125 .seek(SeekFrom::Start(header.v4_index_seek_from_start()))
126 .await
127 .map_err(NewError::ReadFailed)?;
128 loop {
129 let n = stream
130 .read(&mut buf[..])
131 .await
132 .map_err(NewError::ReadFailed)?;
133
134 if n == 0 {
135 return Err(NewError::ReadOtherError(
136 "index_v4 building is not completed",
137 ));
138 }
139
140 if n < n_max_appended {
141 builder.append(&buf[..n]);
142
143 n_max_appended -= n;
144 continue;
145 } else {
146 builder.append(&buf[..n_max_appended]);
147
148 break builder
149 .finish::<IndexV4Querier>()
150 .map_err(NewError::IndexV4BuildFailed)?;
151 }
152 }
153 };
154
155 let mut index_v6 = None;
157 #[allow(clippy::unnecessary_operation)]
158 {
159 if let Some(v6_index_seek_from_start) = header.v6_index_seek_from_start() {
160 let mut builder = IndexV6Querier::builder();
161 let mut n_max_appended = INDEX_LEN as usize;
162 stream
163 .seek(SeekFrom::Start(v6_index_seek_from_start))
164 .await
165 .map_err(NewError::ReadFailed)?;
166 let index_v6_tmp = loop {
167 let n = stream
168 .read(&mut buf[..])
169 .await
170 .map_err(NewError::ReadFailed)?;
171
172 if n == 0 {
173 return Err(NewError::ReadOtherError(
174 "index_v6 building is not completed",
175 ));
176 }
177
178 if n < n_max_appended {
179 builder.append(&buf[..n]);
180
181 n_max_appended -= n;
182 continue;
183 } else {
184 builder.append(&buf[..n_max_appended]);
185
186 break builder
187 .finish::<IndexV6Querier>()
188 .map_err(NewError::IndexV6BuildFailed)?;
189 }
190 };
191 index_v6 = Some(index_v6_tmp);
192 }
193 };
194
195 let records_v4_pool = {
196 let mut pool_objs = vec![];
197
198 for _ in 0..pool_max_size {
199 let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
200
201 stream
202 .seek(SeekFrom::Start(header.total_size as u64))
203 .await
204 .map_err(NewError::SeekFailed)?;
205 let n = stream
206 .read(&mut buf[..1])
207 .await
208 .map_err(NewError::ReadFailed)?;
209 if n != 0 {
210 return Err(NewError::TotalSizeMissing);
211 }
212
213 let pool_obj = RecordsV4Querier::new(stream, header)
215 .map_err(NewError::RecordsV4QuerierNewFailed)?;
216
217 pool_objs.push(pool_obj);
218 }
219
220 Pool::from(pool_objs)
221 };
222
223 let mut records_v6_pool = None;
224 #[allow(clippy::unnecessary_operation)]
225 {
226 if header.has_v6() {
227 let mut pool_objs = vec![];
228
229 for _ in 0..pool_max_size {
230 let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
231
232 stream
233 .seek(SeekFrom::Start(header.total_size as u64))
234 .await
235 .map_err(NewError::SeekFailed)?;
236 let n = stream
237 .read(&mut buf[..1])
238 .await
239 .map_err(NewError::ReadFailed)?;
240 if n != 0 {
241 return Err(NewError::TotalSizeMissing);
242 }
243
244 let pool_obj = RecordsV6Querier::new(stream, header)
246 .map_err(NewError::RecordsV6QuerierNewFailed)?;
247
248 pool_objs.push(pool_obj);
249 }
250
251 records_v6_pool = Some(Pool::from(pool_objs))
252 }
253 };
254
255 let content_pool = {
256 let mut pool_objs = vec![];
257
258 for _ in 0..pool_max_size {
259 let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
260
261 stream
262 .seek(SeekFrom::Start(header.total_size as u64))
263 .await
264 .map_err(NewError::SeekFailed)?;
265 let n = stream
266 .read(&mut buf[..1])
267 .await
268 .map_err(NewError::ReadFailed)?;
269 if n != 0 {
270 return Err(NewError::TotalSizeMissing);
271 }
272
273 let pool_obj = ContentQuerier::new(stream);
275
276 pool_objs.push(pool_obj);
277 }
278
279 Pool::from(pool_objs)
280 };
281
282 Ok(Self {
284 header,
285 index_v4,
286 index_v6,
287 records_v4_pool,
288 records_v6_pool,
289 content_pool,
290 })
291 }
292}
293
294#[derive(Debug)]
296pub enum NewError {
297 OpenFailed(IoError),
298 SeekFailed(IoError),
299 ReadFailed(IoError),
300 ReadOtherError(&'static str),
301 HeaderParseFailed(HeaderParseError),
302 TotalSizeMissing,
303 IndexV4BuildFailed(IndexBuildError),
304 IndexV6BuildFailed(IndexBuildError),
305 RecordsV4QuerierNewFailed(RecordsV4QuerierNewError),
306 RecordsV6QuerierNewFailed(RecordsV6QuerierNewError),
307}
308
309impl core::fmt::Display for NewError {
310 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
311 write!(f, "{self:?}")
312 }
313}
314
315impl std::error::Error for NewError {}
316
317impl<S> Querier<S>
321where
322 S: AsyncSeek + AsyncRead + Unpin,
323{
324 pub async fn lookup(
325 &self,
326 ip: IpAddr,
327 selected_fields: Option<&[RecordField]>,
328 ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, LookupError> {
329 match ip {
330 IpAddr::V4(ip) => self.lookup_ipv4(ip, selected_fields).await,
331 IpAddr::V6(ip) => self.lookup_ipv6(ip, selected_fields).await,
332 }
333 }
334
335 pub async fn lookup_ipv4(
336 &self,
337 ip: Ipv4Addr,
338 selected_fields: Option<&[RecordField]>,
339 ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, LookupError> {
340 let position_range = self.index_v4.query(ip);
341
342 if position_range.end == 0 {
343 return Ok(None);
344 }
345
346 let mut records_v4 = self
348 .records_v4_pool
349 .get()
350 .await
351 .map_err(LookupError::PoolGetFailed)?;
352 let (ip_from, ip_to, mut record_field_contents) = match records_v4
353 .query(ip, position_range)
354 .await
355 .map_err(LookupError::RecordsQueryFailed)?
356 {
357 Some(x) => x,
358 None => return Ok(None),
359 };
360
361 if let Some(selected_fields) = selected_fields {
362 record_field_contents.select(selected_fields);
363 }
364
365 let mut content = self
367 .content_pool
368 .get()
369 .await
370 .map_err(LookupError::PoolGetFailed)?;
371
372 content
373 .fill(&mut record_field_contents)
374 .await
375 .map_err(LookupError::ContentFillFailed)?;
376
377 Ok(Some((ip_from, ip_to, record_field_contents)))
378 }
379
380 pub async fn lookup_ipv6(
381 &self,
382 ip: Ipv6Addr,
383 selected_fields: Option<&[RecordField]>,
384 ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, LookupError> {
385 if let Some(ip) = ip.to_ipv4() {
386 return self.lookup_ipv4(ip, selected_fields).await.map(|x| {
387 x.map(|(ip_from, ip_to, record_field_contents)| {
388 (
389 match ip_from {
390 IpAddr::V4(ip) => ip.to_ipv6_mapped().into(),
391 IpAddr::V6(ip) => {
392 debug_assert!(false, "unreachable");
393 ip.into()
394 }
395 },
396 match ip_to {
397 IpAddr::V4(ip) => ip.to_ipv6_mapped().into(),
398 IpAddr::V6(ip) => {
399 debug_assert!(false, "unreachable");
400 ip.into()
401 }
402 },
403 record_field_contents,
404 )
405 })
406 });
407 }
408
409 let position_range = self
410 .index_v6
411 .as_ref()
412 .map(|x| x.query(ip))
413 .unwrap_or_default();
414
415 if position_range.end == 0 {
416 return Ok(None);
417 }
418
419 let (ip_from, ip_to, mut record_field_contents) = match self.records_v6_pool.as_ref() {
420 Some(records_v6_pool) => {
421 let mut records_v6 = records_v6_pool
423 .get()
424 .await
425 .map_err(LookupError::PoolGetFailed)?;
426
427 match records_v6
428 .query(ip, position_range)
429 .await
430 .map_err(LookupError::RecordsQueryFailed)?
431 {
432 Some(x) => x,
433 None => return Ok(None),
434 }
435 }
436 None => return Ok(None),
437 };
438
439 if let Some(selected_fields) = selected_fields {
440 record_field_contents.select(selected_fields);
441 }
442
443 let mut content = self
445 .content_pool
446 .get()
447 .await
448 .map_err(LookupError::PoolGetFailed)?;
449
450 content
451 .fill(&mut record_field_contents)
452 .await
453 .map_err(LookupError::ContentFillFailed)?;
454
455 Ok(Some((ip_from, ip_to, record_field_contents)))
456 }
457}
458
459#[derive(Debug)]
461pub enum LookupError {
462 PoolGetFailed(PoolError),
463 RecordsQueryFailed(RecordsQueryError),
464 ContentFillFailed(ContentFillError),
465}
466
467impl core::fmt::Display for LookupError {
468 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
469 write!(f, "{self:?}")
470 }
471}
472
473impl std::error::Error for LookupError {}
474
#[cfg(test)]
mod tests {
    use super::*;

    use async_compat::Compat;
    use futures_util::TryFutureExt as _;
    use tokio::fs::File as TokioFile;

    use crate::test_helper::{ip2location_bin_files, ip2proxy_bin_files};

    /// Querier backed by tokio files wrapped in the futures-io compat layer.
    type FileQuerier = Querier<Compat<TokioFile>>;

    /// Opens a querier whose streams are all fresh handles on `path`.
    async fn open_querier<P>(path: &P, pool_max_size: usize) -> Result<FileQuerier, NewError>
    where
        P: AsRef<std::path::Path>,
    {
        let path = path.as_ref().to_path_buf();
        Querier::new(
            || Box::pin(TokioFile::open(path.clone()).map_ok(Compat::new)),
            pool_max_size,
        )
        .await
    }

    /// True when `path` is valid UTF-8 and contains every string in `needles`.
    fn path_contains_all<P>(path: &P, needles: &[&str]) -> bool
    where
        P: AsRef<std::path::Path>,
    {
        path.as_ref()
            .as_os_str()
            .to_str()
            .map_or(false, |s| needles.iter().all(|needle| s.contains(*needle)))
    }

    /// Addresses at both ends of each family; the tests expect no record for
    /// any of them.
    fn boundary_ips() -> [IpAddr; 4] {
        [
            Ipv4Addr::new(0, 0, 0, 0).into(),
            Ipv4Addr::new(255, 255, 255, 255).into(),
            Ipv6Addr::new(0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0).into(),
            Ipv6Addr::new(
                0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
            )
            .into(),
        ]
    }

    #[tokio::test]
    async fn test_new_pool_max_size() -> Result<(), Box<dyn std::error::Error>> {
        // Nothing to check when no database file is available.
        let path = match ip2location_bin_files().first().cloned() {
            Some(x) => x,
            None => return Ok(()),
        };

        // (requested size, expected size): zero is bumped to one.
        let cases: [(usize, usize); 3] = [(0, 1), (1, 1), (2, 2)];
        for (requested, expected) in cases.iter() {
            let q = open_querier(&path, *requested).await?;
            assert_eq!(q.content_pool.status().size, *expected);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_new_and_lookup() -> Result<(), Box<dyn std::error::Error>> {
        let ips = boundary_ips();

        for path in ip2location_bin_files().iter() {
            let q = open_querier(path, 2).await?;

            for ip in ips.iter() {
                let ret = q.lookup(*ip, None).await?;
                assert!(ret.is_none());
            }

            if path_contains_all(path, &["/20220329", "IPV6.BIN"]) {
                q.lookup(
                    Ipv6Addr::from(58569107296622255421594597096899477504).into(),
                    None,
                )
                .await?
                .unwrap();
            }

            if path_contains_all(path, &["/ip2location-sample", "/IP-"]) {
                let ret = q.lookup(Ipv4Addr::new(8, 8, 8, 8).into(), None).await?;
                assert!(ret.is_some());
            }
        }

        for path in ip2proxy_bin_files().iter() {
            let q = open_querier(path, 2).await?;

            for ip in ips.iter() {
                let ret = q.lookup(*ip, None).await?;
                assert!(ret.is_none());
            }

            let out = q
                .lookup(
                    Ipv6Addr::from(58569071813452613185929873510317667680).into(),
                    None,
                )
                .await?;
            assert!(out.is_some());
        }

        Ok(())
    }
}