qdrant_datafusion/
test_utils.rs1use std::collections::VecDeque;
2use std::env;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use testcontainers::core::wait::LogWaitStrategy;
8use testcontainers::core::{IntoContainerPort, Mount};
9use testcontainers::runners::AsyncRunner;
10use testcontainers::{ContainerAsync, GenericImage, ImageExt, TestcontainersError};
11use tokio::sync::RwLock;
12use tokio::time::sleep;
13use tracing::level_filters::LevelFilter;
14use tracing::{debug, error};
15use tracing_subscriber::EnvFilter;
16use tracing_subscriber::prelude::*;
17
// ---- Environment variables that override the container defaults below ----

/// Environment variable overriding the host tests use to reach Qdrant (see [`QDRANT_ENDPOINT`]).
pub const ENDPOINT_ENV: &str = "QDRANT_ENDPOINT";
/// Environment variable overriding the Qdrant image tag (see [`QDRANT_VERSION`]).
pub const VERSION_ENV: &str = "QDRANT_VERSION";
/// Environment variable overriding the container-side REST port (see [`QDRANT_REST_PORT`]).
///
/// NOTE(review): the variable name `QDRANT_NATIVE_PORT` does not say "REST"; it looks
/// inherited from another project's test utilities — confirm before documenting it for users.
pub const REST_PORT_ENV: &str = "QDRANT_NATIVE_PORT";
/// Environment variable overriding the container-side gRPC port (see [`QDRANT_GRPC_PORT`]).
///
/// NOTE(review): this is named `QDRANT_HTTP_PORT` although it controls the *gRPC* port;
/// confirm this is intended.
pub const GRPC_PORT_ENV: &str = "QDRANT_HTTP_PORT";
/// Environment variable overriding the API key configured for the test container
/// (see [`QDRANT_API_KEY`]).
pub const API_KEY_ENV: &str = "QDRANT_API_KEY";
/// Environment variable set *inside* the container to hand the API key to Qdrant.
pub const QDRANT_API_KEY_ENV: &str = "QDRANT__SERVICE__API_KEY";

// ---- Defaults used when the corresponding environment variable is absent ----

/// Default Qdrant image tag.
pub const QDRANT_VERSION: &str = "latest";
/// Default container-side REST port.
pub const QDRANT_REST_PORT: u16 = 6333;
/// Default container-side gRPC port.
pub const QDRANT_GRPC_PORT: u16 = 6334;
/// Default host used to reach the container.
pub const QDRANT_ENDPOINT: &str = "localhost";
/// Directory, relative to `CARGO_MANIFEST_DIR`, holding the Qdrant config files to mount.
pub const QDRANT_CONFIG_SRC: &str = "tests/bin/";
/// Path inside the container where the selected config file is bind-mounted.
pub const QDRANT_CONFIG_DEST: &str = "/qdrant/config/config.yaml";
/// Default API key configured on the test container.
pub const QDRANT_API_KEY: &str = "qdrant-datafusion-api-key";
32
33pub fn init_tracing(directives: Option<&[(&str, &str)]>) {
35 let rust_log = env::var("RUST_LOG").unwrap_or_default();
36
37 let stdio_logger = tracing_subscriber::fmt::Layer::default()
38 .with_level(true)
39 .with_file(true)
40 .with_line_number(true)
41 .with_filter(get_filter(&rust_log, directives));
42
43 if tracing::subscriber::set_global_default(tracing_subscriber::registry().with(stdio_logger))
45 .is_ok()
46 {
47 debug!("Tracing initialized with RUST_LOG={rust_log}");
48 }
49}
50
51#[allow(unused)]
55pub fn get_filter(rust_log: &str, directives: Option<&[(&str, &str)]>) -> EnvFilter {
56 let mut env_dirs = vec![];
57 let level = if rust_log.is_empty() {
58 LevelFilter::WARN.to_string()
59 } else if let Ok(level) = LevelFilter::from_str(rust_log) {
60 level.to_string()
61 } else {
62 let mut parts = rust_log.split(',');
63 let level = parts.next().and_then(|p| LevelFilter::from_str(p).ok());
64 env_dirs = parts
65 .map(|s| s.split('=').collect::<VecDeque<_>>())
66 .filter(|s| s.len() == 2)
67 .map(|mut s| (s.pop_front().unwrap(), s.pop_front().unwrap()))
68 .collect::<Vec<_>>();
69 level.unwrap_or(LevelFilter::WARN).to_string()
70 };
71
72 let mut filter = EnvFilter::new(level)
73 .add_directive("ureq=info".parse().unwrap())
74 .add_directive("tokio=info".parse().unwrap())
75 .add_directive("runtime=error".parse().unwrap())
76 .add_directive("opentelemetry_sdk=off".parse().unwrap());
77
78 if let Some(directives) = directives {
79 for (key, value) in directives {
80 filter = filter.add_directive(format!("{key}={value}").parse().unwrap());
81 }
82 }
83
84 for (key, value) in env_dirs {
85 filter = filter.add_directive(format!("{key}={value}").parse().unwrap());
86 }
87
88 filter
89}
90
91pub async fn create_container(conf: Option<&str>) -> Arc<QdrantContainer> {
94 let c = QdrantContainer::try_new(conf).await.expect("Failed to initialize Qdrant container");
95 Arc::new(c)
96}
97
98pub struct QdrantContainer {
99 pub endpoint: String,
100 pub rest_port: u16,
101 pub grpc_port: u16,
102 pub api_key: String,
103 container: RwLock<Option<ContainerAsync<GenericImage>>>,
104}
105
106impl QdrantContainer {
107 pub async fn try_new(conf: Option<&str>) -> Result<Self, TestcontainersError> {
109 let version = env::var(VERSION_ENV).unwrap_or(QDRANT_VERSION.to_string());
111 let rest_port = env::var(REST_PORT_ENV)
112 .ok()
113 .and_then(|p| p.parse::<u16>().ok())
114 .unwrap_or(QDRANT_REST_PORT);
115 let grpc_port = env::var(GRPC_PORT_ENV)
116 .ok()
117 .and_then(|p| p.parse::<u16>().ok())
118 .unwrap_or(QDRANT_GRPC_PORT);
119 let api_key = env::var(API_KEY_ENV).ok().unwrap_or(QDRANT_API_KEY.into());
120
121 let image = GenericImage::new("qdrant/qdrant", &version)
123 .with_exposed_port(rest_port.tcp())
124 .with_exposed_port(grpc_port.tcp())
125 .with_wait_for(testcontainers::core::WaitFor::Log(LogWaitStrategy::stdout_or_stderr(
126 "Qdrant gRPC listening",
127 )))
128 .with_env_var(QDRANT_API_KEY_ENV, &api_key)
129 .with_mount(Mount::bind_mount(
130 format!(
131 "{}/{QDRANT_CONFIG_SRC}/{}",
132 env!("CARGO_MANIFEST_DIR"),
133 conf.unwrap_or("config.yaml")
134 ),
135 QDRANT_CONFIG_DEST,
136 ));
137
138 let container = image.start().await?;
140
141 let rest_port = container.get_host_port_ipv4(rest_port).await?;
143 let grpc_port = container.get_host_port_ipv4(grpc_port).await?;
144
145 let endpoint = env::var(ENDPOINT_ENV).unwrap_or(QDRANT_ENDPOINT.to_string());
147
148 sleep(Duration::from_secs(2)).await;
150
151 let container = RwLock::new(Some(container));
152 Ok(QdrantContainer {
153 endpoint,
154 rest_port,
155 grpc_port,
156 api_key: api_key.to_string(),
157 container,
158 })
159 }
160
161 pub fn get_url(&self) -> String { format!("http://{}:{}", self.endpoint, self.grpc_port) }
162
163 pub fn get_api_key(&self) -> &str { &self.api_key }
164
165 pub async fn shutdown(&self) -> Result<(), TestcontainersError> {
167 let mut container = self.container.write().await;
168 if let Some(container) = container.take() {
169 let _ = container
170 .stop_with_timeout(Some(0))
171 .await
172 .inspect_err(|error| {
173 error!(?error, "Failed to stop container, will attempt to remove");
174 })
175 .ok();
176 let _ = container
177 .rm()
178 .await
179 .inspect_err(|error| {
180 error!(?error, "Failed to rm container, cleanup manually");
181 })
182 .ok();
183 }
184 Ok(())
185 }
186}