1use crate::collection::Collection;
18use crate::config::RPC_TIMEOUT;
19use crate::error::{Error, Result};
20use crate::options::CreateCollectionOptions;
21pub use crate::proto::common::ConsistencyLevel;
22use crate::proto::common::MsgType;
23use crate::proto::milvus::milvus_service_client::MilvusServiceClient;
24use crate::proto::milvus::{
25 CreateCollectionRequest, DescribeCollectionRequest, DropCollectionRequest, FlushRequest,
26 HasCollectionRequest, ShowCollectionsRequest,
27};
28use crate::schema::CollectionSchema;
29use crate::utils::{new_msg, status_to_result};
30use base64::Engine;
31use base64::engine::general_purpose;
32use prost::bytes::BytesMut;
33use prost::Message;
34use tonic::{Request};
35use tonic::service::Interceptor;
36use std::collections::HashMap;
37use std::convert::TryInto;
38use std::time::Duration;
39use tonic::codegen::{StdError, InterceptedService};
40use tonic::transport::Channel;
41
/// gRPC interceptor that injects Milvus credentials into outgoing request metadata.
///
/// `Client::with_timeout` builds the token as the standard base64 encoding of
/// `"username:password"`; when no credentials were supplied the token is `None`
/// and requests pass through the interceptor unchanged.
#[derive(Clone)]
pub struct AuthInterceptor {
    /// Value sent in the `authorization` metadata header, if any.
    token: Option<String>,
}
46
47impl Interceptor for AuthInterceptor {
48 fn call(&mut self, mut req: Request<()>) -> std::result::Result<tonic::Request<()>, tonic::Status> {
49 if let Some(ref token) = self.token {
50 let header_value = format!("{}", token);
51 req.metadata_mut()
52 .insert("authorization", header_value.parse().unwrap());
53 }
54
55 Ok(req)
56 }
57}
58
59#[derive(Clone)]
60pub struct ClientBuilder<D> {
61 dst: D,
62 username: Option<String>,
63 password: Option<String>,
64}
65
66impl<D> ClientBuilder<D>
67where
68 D: TryInto<tonic::transport::Endpoint> + Clone,
69 D::Error: Into<StdError>,
70 D::Error: std::fmt::Debug,
71{
72 pub fn new(dst: D) -> Self {
73 Self {
74 dst,
75 username: None,
76 password: None,
77 }
78 }
79
80 pub fn username(mut self, username: &str) -> Self {
81 self.username = Some(username.to_owned());
82 self
83 }
84
85 pub fn password(mut self, password: &str) -> Self {
86 self.password = Some(password.to_owned());
87 self
88 }
89
90 pub async fn build(self) -> Result<Client> {
91 Client::with_timeout(self.dst, RPC_TIMEOUT, self.username, self.password).await
92 }
93}
94
95#[derive(Clone)]
96pub struct Client {
97 client: MilvusServiceClient<InterceptedService<Channel, AuthInterceptor>>,
98}
99
100impl Client {
101 pub async fn new<D>(dst: D) -> Result<Self>
102 where
103 D: TryInto<tonic::transport::Endpoint>,
104 D::Error: Into<StdError>,
105 D::Error: std::fmt::Debug,
106 {
107 Self::with_timeout(dst, RPC_TIMEOUT, None, None).await
108 }
109
110 pub async fn with_timeout<D>(
111 dst: D,
112 timeout: Duration,
113 username: Option<String>,
114 password: Option<String>,
115 ) -> Result<Self>
116 where
117 D: TryInto<tonic::transport::Endpoint>,
118 D::Error: Into<StdError>,
119 D::Error: std::fmt::Debug,
120 {
121 let mut dst: tonic::transport::Endpoint = dst.try_into().map_err(|err| {
122 Error::InvalidParameter("url".to_owned(), format!("to parse {:?}", err))
123 })?;
124
125 dst = dst.timeout(timeout);
126
127 let token = match (username, password) {
128 (Some(username), Some(password)) => {
129 let auth_token = format!("{}:{}", username, password);
130 let auth_token = general_purpose::STANDARD.encode(auth_token);
131 Some(auth_token)
132 }
133 _ => None,
134 };
135
136 let auth_interceptor = AuthInterceptor { token };
137
138 let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
139
140 let client = MilvusServiceClient::with_interceptor(conn, auth_interceptor);
141
142 Ok(Self { client })
143 }
144
145 pub async fn create_collection(
146 &self,
147 schema: CollectionSchema,
148 options: Option<CreateCollectionOptions>,
149 ) -> Result<Collection> {
150 let options = options.unwrap_or_default();
151 let schema: crate::proto::schema::CollectionSchema = schema.into();
152 let mut buf = BytesMut::new();
153
154 schema.encode(&mut buf)?;
155
156 let status = self
157 .client
158 .clone()
159 .create_collection(CreateCollectionRequest {
160 base: Some(new_msg(MsgType::CreateCollection)),
161 collection_name: schema.name.to_string(),
162 schema: buf.to_vec(),
163 shards_num: options.shard_num,
164 consistency_level: options.consistency_level as i32,
165 ..Default::default()
166 })
167 .await?
168 .into_inner();
169
170 status_to_result(&Some(status))?;
171
172 Ok(self.get_collection(&schema.name).await?)
173 }
174
175 pub async fn get_collection(&self, collection_name: &str) -> Result<Collection> {
176 let resp = self
177 .client
178 .clone()
179 .describe_collection(DescribeCollectionRequest {
180 base: Some(new_msg(MsgType::DescribeCollection)),
181 db_name: "".to_owned(),
182 collection_name: collection_name.to_owned(),
183 collection_id: 0,
184 time_stamp: 0,
185 })
186 .await?
187 .into_inner();
188
189 status_to_result(&resp.status)?;
190
191 Ok(Collection::new(self.client.clone(), resp))
192 }
193
194 pub async fn has_collection<S>(&self, name: S) -> Result<bool>
195 where
196 S: Into<String>,
197 {
198 let name = name.into();
199 let res = self
200 .client
201 .clone()
202 .has_collection(HasCollectionRequest {
203 base: Some(new_msg(MsgType::HasCollection)),
204 db_name: "".to_string(),
205 collection_name: name.clone(),
206 time_stamp: 0,
207 })
208 .await?
209 .into_inner();
210
211 status_to_result(&res.status)?;
212
213 Ok(res.value)
214 }
215
216 pub async fn drop_collection<S>(&self, name: S) -> Result<()>
217 where
218 S: Into<String>,
219 {
220 status_to_result(&Some(
221 self.client
222 .clone()
223 .drop_collection(DropCollectionRequest {
224 base: Some(new_msg(MsgType::DropCollection)),
225 collection_name: name.into(),
226 ..Default::default()
227 })
228 .await?
229 .into_inner(),
230 ))
231 }
232
233 pub async fn list_collections(&self) -> Result<Vec<String>> {
234 let response = self
235 .client
236 .clone()
237 .show_collections(ShowCollectionsRequest {
238 base: Some(new_msg(MsgType::ShowCollections)),
239 ..Default::default()
240 })
241 .await?
242 .into_inner();
243
244 status_to_result(&response.status)?;
245 Ok(response.collection_names)
246 }
247
248 pub async fn flush_collections<C>(&self, collections: C) -> Result<HashMap<String, Vec<i64>>>
249 where
250 C: IntoIterator,
251 C::Item: ToString,
252 {
253 let res = self
254 .client
255 .clone()
256 .flush(FlushRequest {
257 base: Some(new_msg(MsgType::Flush)),
258 db_name: "".to_string(),
259 collection_names: collections.into_iter().map(|x| x.to_string()).collect(),
260 })
261 .await?
262 .into_inner();
263
264 status_to_result(&res.status)?;
265
266 Ok(res
267 .coll_seg_i_ds
268 .into_iter()
269 .map(|(k, v)| (k, v.data))
270 .collect())
271 }
272}