use std::convert::TryInto;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::time::Duration;
use bytes::Bytes;
use futures::stream::{self, Stream, StreamExt, TryStreamExt};
use reqwest::header::{self, HeaderMap, HeaderValue};
use reqwest::{multipart, IntoUrl, StatusCode, Url};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWriteExt};
use tokio_util::codec;
use anyhow::{anyhow, Context};
use crate::archive::ArchiveMode;
use crate::error::{ApiError, KaggleError};
use crate::models::extended::{
Competition,
Dataset,
DatasetMetadata,
DatasetNewResponse,
DatasetNewVersionResponse,
File,
FileUploadInfo,
Kernel,
KernelOutput,
KernelPullResponse,
KernelPushResponse,
LeaderBoard,
ListFilesResult,
Submission,
SubmitResult,
};
use crate::models::metadata::{Metadata, Resource};
use crate::models::{
DatasetNew,
DatasetNewRequest,
DatasetNewVersionRequest,
DatasetUpdateSettingsRequest,
DatasetUploadFile,
Error,
KernelPushRequest,
};
use crate::query::{PushKernelType, PushLanguageType};
use crate::request::{CompetitionsList, DatasetsList, KernelPullRequest, KernelsList};
use std::collections::HashMap;
use tempdir::TempDir;
use log::debug;
/// Client for the Kaggle API v1 over HTTP.
#[derive(Clone)]
pub struct KaggleApiClient {
/// Shared HTTP client; `Rc` makes clones cheap (note: not `Send`).
client: Rc<reqwest::Client>,
/// Root of the API, e.g. `https://www.kaggle.com/api/v1/`.
base_url: Url,
/// Username/key pair; the username also serves as the default owner slug.
credentials: KaggleCredentials,
/// Directory downloads default into when no target path is given.
download_dir: PathBuf,
}
impl KaggleApiClient {
/// File name of the dataset metadata descriptor looked for in dataset folders.
const DATASET_METADATA_FILE: &'static str = "dataset-metadata.json";
/// Legacy dataset metadata file name, still accepted as a fallback.
const OLD_DATASET_METADATA_FILE: &'static str = "datapackage.json";
/// File name of the kernel metadata descriptor.
const KERNEL_METADATA_FILE: &'static str = "kernel-metadata.json";
/// Convenience entry point for configuring a client.
#[inline]
pub fn builder() -> KaggleApiClientBuilder {
KaggleApiClientBuilder::default()
}
/// Directory used as the default download target.
pub fn download_dir(&self) -> &PathBuf {
&self.download_dir
}
}
/// Step-by-step configuration for [`KaggleApiClient`].
#[derive(Debug, Clone)]
pub struct KaggleApiClientBuilder {
/// API root; preset to the public v1 endpoint by `Default`.
base_url: Url,
/// Optional `User-Agent` override; crate name/version is used when `None`.
user_agent: Option<String>,
/// Optional preconfigured HTTP client.
client: Option<Rc<reqwest::Client>>,
/// Extra default headers to send with every request.
headers: Option<HeaderMap>,
/// Credential source; default config-file lookup when `None`.
auth: Option<Authentication>,
/// Default download directory; current working dir when `None`.
download_dir: Option<PathBuf>,
}
impl KaggleApiClientBuilder {
    /// Replace the complete set of default headers sent with every request.
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        self.headers = Some(headers);
        self
    }
    /// Set the default directory downloads are stored in.
    pub fn download_dir<T: Into<PathBuf>>(mut self, download_dir: T) -> Self {
        self.download_dir = Some(download_dir.into());
        self
    }
    /// Mutable access to the header map, creating it on first use.
    pub fn headers_mut(&mut self) -> &mut HeaderMap {
        // Idiom: `Option::get_or_insert_with` replaces the manual
        // is_none/insert/unwrap dance.
        self.headers
            .get_or_insert_with(|| HeaderMap::with_capacity(2))
    }
    /// Override the `User-Agent` header.
    pub fn user_agent<T: ToString>(mut self, user_agent: T) -> Self {
        self.user_agent = Some(user_agent.to_string());
        self
    }
    /// Use a preconfigured `reqwest` client.
    ///
    /// NOTE: when a client is supplied here, the headers assembled by
    /// `build` (including the basic-auth header) are NOT applied to it —
    /// the caller is responsible for its default headers.
    pub fn client(mut self, client: Rc<reqwest::Client>) -> Self {
        self.client = Some(client);
        self
    }
    /// Select how credentials are obtained (env vars, config file, explicit).
    pub fn auth(mut self, auth: Authentication) -> Self {
        self.auth = Some(auth);
        self
    }
    /// Resolve credentials, assemble the default headers (HTTP basic auth
    /// and `User-Agent`) and construct the client.
    ///
    /// # Errors
    /// Fails when credentials cannot be resolved, a header value is
    /// malformed, or the underlying HTTP client cannot be built.
    pub fn build(self) -> anyhow::Result<KaggleApiClient> {
        let credentials = self.auth.unwrap_or_default().credentials()?;
        let mut headers = self.headers.unwrap_or_else(|| HeaderMap::with_capacity(2));
        // `Basic {base64(username:key)}` — stream the encoded pair directly
        // into the header buffer.
        let mut header_value = b"Basic ".to_vec();
        {
            let mut encoder =
                base64::write::EncoderWriter::new(&mut header_value, base64::STANDARD);
            write!(encoder, "{}:{}", &credentials.username, &credentials.key)?;
        }
        headers.insert(header::AUTHORIZATION, header_value.try_into()?);
        if let Some(user_agent) = self.user_agent {
            headers.insert(header::USER_AGENT, user_agent.parse()?);
        } else {
            headers.insert(
                header::USER_AGENT,
                HeaderValue::from_static(concat!(
                    env!("CARGO_PKG_NAME"),
                    "/",
                    env!("CARGO_PKG_VERSION"),
                )),
            );
        }
        let client = match self.client {
            Some(client) => client,
            None => Rc::new(
                reqwest::Client::builder()
                    .default_headers(headers)
                    .build()?,
            ),
        };
        let download_dir = match self.download_dir {
            Some(path) => path,
            None => std::env::current_dir()?,
        };
        Ok(KaggleApiClient {
            client,
            base_url: self.base_url,
            credentials,
            download_dir,
        })
    }
}
impl Default for KaggleApiClientBuilder {
    /// Builder preset with the public Kaggle API v1 endpoint and no other
    /// overrides; everything else is resolved in `build`.
    fn default() -> Self {
        let base_url = "https://www.kaggle.com/api/v1/"
            .parse()
            .expect("static base url is valid");
        Self {
            base_url,
            user_agent: None,
            client: None,
            headers: None,
            auth: None,
            download_dir: None,
        }
    }
}
/// Username/key pair as stored in `kaggle.json` or the `KAGGLE_*` env vars.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct KaggleCredentials {
/// Kaggle account name; also used as the default owner slug.
username: String,
/// API token generated from the Kaggle account page.
key: String,
}
impl KaggleCredentials {
    /// Read credentials from the `KAGGLE_USERNAME`/`KAGGLE_KEY` env vars.
    fn from_env() -> anyhow::Result<Self> {
        let username = std::env::var("KAGGLE_USERNAME")
            .context("KAGGLE_USERNAME env variable not present.")?;
        let key = std::env::var("KAGGLE_KEY").context("KAGGLE_KEY env variable not present.")?;
        Ok(KaggleCredentials { username, key })
    }
    /// Load `kaggle.json` from `KAGGLE_CONFIG_DIR` if set, otherwise from
    /// `~/.kaggle/kaggle.json` in the user's home directory.
    fn from_default_json() -> anyhow::Result<Self> {
        match std::env::var("KAGGLE_CONFIG_DIR") {
            Ok(dir) => Self::from_json(dir),
            Err(_) => {
                let default = dirs::home_dir()
                    .map(|p| p.join(".kaggle/kaggle.json"))
                    .context("Failed to detect home directory.")?;
                Self::from_json(default)
            }
        }
    }
    /// Parse credentials from the JSON file at `path`.
    fn from_json<T: AsRef<Path>>(path: T) -> anyhow::Result<Self> {
        let path = path.as_ref();
        if path.exists() {
            let content = std::fs::read_to_string(path)?;
            Ok(serde_json::from_str(&content)?)
        } else {
            Err(anyhow!(
                "kaggle config file {} does not exist",
                path.display()
            ))
        }
    }
}
/// Where the API credentials come from.
#[derive(Debug, Clone)]
pub enum Authentication {
/// `KAGGLE_USERNAME` / `KAGGLE_KEY` environment variables.
Env,
/// A `kaggle.json` config file; `None` means the default location
/// (`KAGGLE_CONFIG_DIR` or `~/.kaggle/kaggle.json`).
ConfigFile { path: Option<PathBuf> },
/// Explicit username/key pair supplied in code.
Credentials { user_name: String, key: String },
}
impl Authentication {
    /// Authenticate with an explicit username/key pair.
    pub fn with_credentials<S: ToString, T: ToString>(user_name: S, key: T) -> Self {
        let user_name = user_name.to_string();
        let key = key.to_string();
        Authentication::Credentials { user_name, key }
    }
    /// Authenticate from a specific `kaggle.json` config file.
    pub fn with_config_file(path: impl AsRef<Path>) -> Self {
        let path = Some(path.as_ref().to_path_buf());
        Authentication::ConfigFile { path }
    }
}
impl Authentication {
    /// Resolve this strategy into concrete credentials.
    fn credentials(self) -> anyhow::Result<KaggleCredentials> {
        match self {
            Authentication::Env => KaggleCredentials::from_env(),
            Authentication::ConfigFile { path: Some(path) } => KaggleCredentials::from_json(path),
            Authentication::ConfigFile { path: None } => KaggleCredentials::from_default_json(),
            Authentication::Credentials { user_name, key } => Ok(KaggleCredentials {
                username: user_name,
                key,
            }),
        }
    }
}
impl Default for Authentication {
fn default() -> Self {
Authentication::ConfigFile { path: None }
}
}
/// Marker type for API responses carrying no payload.
/// NOTE(review): not referenced anywhere in this file — confirm it is used
/// elsewhere in the crate before removing.
pub struct ApiResp;
impl KaggleApiClient {
/// Resolve `path` against the configured API base url.
#[inline]
fn join_url<T: AsRef<str>>(&self, path: T) -> anyhow::Result<Url> {
    let url = self.base_url.join(path.as_ref())?;
    Ok(url)
}
/// Split `id` into `(owner, slug)`.
///
/// A bare `slug` (no `/`) is attributed to the authenticated user; more
/// than two `/`-separated segments is rejected.
pub fn get_user_and_identifier_slug<'a>(
    &'a self,
    id: &'a str,
) -> Result<(&'a str, &'a str), KaggleError> {
    let segments: Vec<&str> = id.split('/').collect();
    match segments.as_slice() {
        [slug] => Ok((self.credentials.username.as_str(), *slug)),
        [user, slug] => Ok((*user, *slug)),
        _ => Err(KaggleError::meta(format!(
            "Invalid identifier string. expected form `{{username}}/{{identifier-slug}}`, but got {}",
            id
        ))),
    }
}
/// POST `url` with an optional JSON `body` and decode the JSON response.
async fn post_json<T: DeserializeOwned, U: IntoUrl, B: Serialize + ?Sized>(
    &self,
    url: U,
    body: Option<&B>,
) -> anyhow::Result<T> {
    let builder = self.client.post(url).header(
        header::ACCEPT,
        header::HeaderValue::from_static("application/json"),
    );
    let builder = match body {
        Some(json) => builder.json(json),
        None => builder,
    };
    Self::request_json(builder).await
}
/// GET `url` and decode the JSON response.
async fn get_json<T: DeserializeOwned, U: IntoUrl>(&self, url: U) -> anyhow::Result<T> {
    let url = url.into_url()?;
    debug!("GET: {}", url);
    Self::request_json(self.client.get(url)).await
}
async fn request_json<T: DeserializeOwned>(req: reqwest::RequestBuilder) -> anyhow::Result<T> {
debug!("Request: {:?}", req);
let full = Self::request(req).await?.bytes().await?;
match serde_json::from_slice::<T>(&full) {
Ok(resp) => Ok(resp),
Err(err) => {
if let Ok(api_err) = serde_json::from_slice::<crate::models::Error>(&full) {
Err(KaggleError::Api {
err: ApiError::ServerError(api_err),
}
.into())
} else {
Err(err.into())
}
}
}
}
async fn request(req: reqwest::RequestBuilder) -> anyhow::Result<reqwest::Response> {
let resp = req.send().await?;
if resp.status().is_success() {
Ok(resp)
} else {
let status = resp.status();
if let Ok(err) = resp.json::<Error>().await {
return Err(KaggleError::Api {
err: ApiError::ServerError(err),
}
.into());
}
let err = match status {
StatusCode::UNAUTHORIZED => ApiError::Unauthorized,
status => ApiError::Other(status.as_u16()),
};
Err(KaggleError::Api { err }.into())
}
}
/// Stream the response body chunk by chunk into `output` on disk.
async fn write_resp(
    mut res: reqwest::Response,
    output: impl AsRef<Path>,
) -> anyhow::Result<PathBuf> {
    let target = output.as_ref().to_path_buf();
    let mut file = tokio::fs::File::create(&target).await?;
    while let Some(chunk) = res.chunk().await? {
        file.write_all(&chunk).await?;
    }
    Ok(target)
}
/// Execute `req` and write the successful response body to `output`.
async fn download_file(
    req: reqwest::RequestBuilder,
    output: impl AsRef<Path>,
) -> anyhow::Result<PathBuf> {
    let resp = Self::request(req).await?;
    Self::write_resp(resp, output).await
}
/// Locate and deserialize the dataset metadata file under `path`.
pub(crate) async fn read_dataset_metadata_file(
    path: impl AsRef<Path>,
) -> anyhow::Result<Metadata> {
    let meta_file = Self::get_dataset_metadata_file(path)?;
    let content = tokio::fs::read(&meta_file).await?;
    let metadata = serde_json::from_slice(&content)?;
    Ok(metadata)
}
/// Locate and deserialize the kernel metadata file under `path`.
async fn read_kernel_metadata_file(path: impl AsRef<Path>) -> anyhow::Result<Metadata> {
    let meta_file = Self::get_kernel_metadata_file(path)?;
    let content = tokio::fs::read(&meta_file).await?;
    let metadata = serde_json::from_slice(&content)?;
    Ok(metadata)
}
/// Resolve the dataset metadata file for `path`.
///
/// A directory is searched for `dataset-metadata.json`, then the legacy
/// `datapackage.json`; a file path is used directly when it exists.
fn get_dataset_metadata_file(path: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
    let path = path.as_ref().to_path_buf();
    if !path.is_dir() {
        return if path.exists() {
            Ok(path)
        } else {
            Err(KaggleError::FileNotFound(path).into())
        };
    }
    let file = path.join(Self::DATASET_METADATA_FILE);
    if file.exists() {
        return Ok(file);
    }
    let old = path.join(Self::OLD_DATASET_METADATA_FILE);
    if old.exists() {
        Ok(old)
    } else {
        // Report the modern file name in the error, matching the lookup order.
        Err(KaggleError::FileNotFound(file).into())
    }
}
/// Resolve the kernel metadata file for `path` (directory or direct path).
fn get_kernel_metadata_file(path: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
    let path = path.as_ref().to_path_buf();
    let candidate = if path.is_dir() {
        path.join(Self::KERNEL_METADATA_FILE)
    } else {
        path
    };
    if candidate.exists() {
        Ok(candidate)
    } else {
        Err(KaggleError::FileNotFound(candidate).into())
    }
}
/// Return the content length in bytes and the last-modified time of `file`.
///
/// The timestamp is the duration since the Unix epoch: the API consumes it
/// as an absolute `lastModifiedDateUtc` (the official Python client sends
/// the file's mtime). Files whose mtime is unavailable fall back to the
/// current time.
fn get_file_metadata(file: impl AsRef<Path>) -> anyhow::Result<(u64, Duration)> {
    let meta = file.as_ref().metadata()?;
    let content_length = meta.len();
    // Fix: the previous implementation used `elapsed()`, which yields the
    // time *since* modification rather than an absolute UTC timestamp.
    let last_modified = meta
        .modified()
        .unwrap_or_else(|_| std::time::SystemTime::now())
        .duration_since(std::time::UNIX_EPOCH)?;
    Ok((content_length, last_modified))
}
/// Upload a single dataset file: request an upload slot, push the bytes to
/// the returned url, then describe the file from its resource entry.
async fn upload_dataset_file(
    &self,
    file: impl AsRef<Path>,
    file_name: impl AsRef<str>,
    item: Option<&Resource>,
) -> anyhow::Result<DatasetUploadFile> {
    let file = file.as_ref();
    let (content_length, last_modified) = Self::get_file_metadata(file)?;
    let info = self
        .datasets_upload_file(file_name.as_ref(), content_length, last_modified)
        .await?;
    self.upload_complete(file, &info.create_url).await?;
    let mut upload_file = DatasetUploadFile::new(info.token);
    if let Some(item) = item {
        if let Some(desc) = &item.description {
            upload_file.set_description(desc.clone());
        }
        // Fix: the schema columns were previously assigned twice in a row;
        // a single assignment is sufficient.
        if let Some(schema) = &item.schema {
            upload_file.set_columns(schema.get_processed_columns());
        }
    }
    Ok(upload_file)
}
/// Upload every resource under `folder`, archiving sub-directories
/// according to `dir_mode`, and return the upload descriptors.
async fn upload_files(
&self,
folder: impl AsRef<Path>,
resources: &[Resource],
dir_mode: ArchiveMode,
) -> anyhow::Result<Vec<DatasetUploadFile>> {
let mut uploads = Vec::with_capacity(resources.len());
let folder = folder.as_ref();
// Map each resource's absolute path back to its resource entry.
let resource_paths: HashMap<_, _> = resources
.iter()
.map(|x| (folder.join(&x.path), x))
.collect();
// Created lazily, only when a directory has to be archived first.
let mut tmp_archive_dir = None;
// Metadata descriptors are never uploaded as data files.
let skip = &[
Self::DATASET_METADATA_FILE,
Self::OLD_DATASET_METADATA_FILE,
Self::KERNEL_METADATA_FILE,
];
for (entry, resource) in resource_paths {
// Resources missing on disk are silently skipped.
if !entry.exists() {
continue;
}
let file_name = entry
.file_name()
.context("File path terminates in `..`")?
.to_str()
.context("File name is not valid unicode")?
.to_string();
let mut upload = None;
if entry.is_file() {
if skip.contains(&file_name.as_str()) {
continue;
}
upload = Some(entry);
} else if entry.is_dir() {
if tmp_archive_dir.is_none() {
tmp_archive_dir = Some(TempDir::new("kaggle-upload")?);
}
// Archive the directory into the temp dir; `make_archive` may
// return `None`, in which case nothing is uploaded for it.
let archive_path = tmp_archive_dir.as_ref().unwrap().path().join(&file_name);
upload = dir_mode.make_archive(entry, &archive_path)?;
}
if let Some(upload) = upload {
let upload_file = self
.upload_dataset_file(upload, &file_name, Some(&resource))
.await?;
uploads.push(upload_file);
}
}
// Close the temp dir eagerly so archive files are cleaned up here and
// any removal error is reported.
if let Some(tmp) = tmp_archive_dir {
tmp.close()?;
}
Ok(uploads)
}
}
impl KaggleApiClient {
/// List competitions matching the given filter/query.
pub async fn competitions_list(
    &self,
    competition: &CompetitionsList,
) -> anyhow::Result<Vec<Competition>> {
    let req = self
        .client
        .get(self.join_url("competitions/list")?)
        .query(competition);
    Self::request_json(req).await
}
/// Download the leaderboard zip for competition `id`.
///
/// `output` may be a directory (the default file name is placed inside
/// it), a full file path, or `None` to use the client download dir.
pub async fn competition_download_leaderboard(
    &self,
    id: impl AsRef<str>,
    output: Option<PathBuf>,
) -> anyhow::Result<PathBuf> {
    let id = id.as_ref();
    let default_name = format!("{}-leaderboard.zip", id);
    let output = match output {
        Some(target) if target.is_dir() => target.join(default_name),
        Some(target) => target,
        None => self.download_dir.join(default_name),
    };
    let req = self
        .client
        .get(self.join_url(format!("competitions/{}/leaderboard/download", id))?);
    Self::download_file(req, output).await
}
/// Fetch the leaderboard for competition `id`.
pub async fn competition_view_leaderboard(
    &self,
    id: impl AsRef<str>,
) -> anyhow::Result<LeaderBoard> {
    let url = self.join_url(format!("competitions/{}/leaderboard/view", id.as_ref()))?;
    Self::request_json(self.client.get(url)).await
}
/// Download a single data file of competition `id` to `target` (default:
/// `{file_name}.zip` in the client download dir).
pub async fn competitions_data_download_file(
    &self,
    id: impl AsRef<str>,
    file_name: impl AsRef<str>,
    target: Option<PathBuf>,
) -> anyhow::Result<PathBuf> {
    let file_name = file_name.as_ref();
    let output = match target {
        Some(path) => path,
        None => self.download_dir.join(format!("{}.zip", file_name)),
    };
    let url = self.join_url(format!(
        "competitions/data/download/{}/{}",
        id.as_ref(),
        file_name
    ))?;
    Self::download_file(self.client.get(url), output).await
}
/// Download all data files of competition `id` as a single zip archive.
pub async fn competitions_data_download_all_files(
    &self,
    id: impl AsRef<str>,
    target: Option<PathBuf>,
) -> anyhow::Result<PathBuf> {
    let id = id.as_ref();
    let output = match target {
        Some(path) => path,
        None => self.download_dir.join(format!("{}.zip", id)),
    };
    let url = self.join_url(format!("competitions/data/download-all/{}", id))?;
    Self::download_file(self.client.get(url), output).await
}
/// List the data files of competition `id`.
pub async fn competitions_data_list_files(
    &self,
    id: impl AsRef<str>,
) -> anyhow::Result<Vec<File>> {
    let url = self.join_url(format!("competitions/data/list/{}", id.as_ref()))?;
    Self::request_json(self.client.get(url)).await
}
/// List the authenticated user's submissions for competition `id` (paged).
pub async fn competitions_submissions_list(
    &self,
    id: impl AsRef<str>,
    page: usize,
) -> anyhow::Result<Vec<Submission>> {
    let url = self.join_url(format!("competitions/submissions/list/{}", id.as_ref()))?;
    Self::request_json(self.client.get(url).query(&[("page", page)])).await
}
/// Finalize a submission from an already-uploaded blob token.
pub async fn competitions_submissions_submit(
    &self,
    id: impl AsRef<str>,
    blob_file_tokens: impl ToString,
    submission_description: impl ToString,
) -> anyhow::Result<SubmitResult> {
    let url = self.join_url(format!("competitions/submissions/submit/{}", id.as_ref()))?;
    let form = multipart::Form::new()
        .text("blobFileTokens", blob_file_tokens.to_string())
        .text("submissionDescription", submission_description.to_string());
    Self::request_json(self.client.post(url).multipart(form)).await
}
/// Submit `file` to `competition` with the given description `message`.
///
/// First asks the API for an upload url. Legacy responses (recognized by
/// the presence of `isComplete`, mirroring the official Python client)
/// embed `.../{guid}/{contentLength}/{lastModifiedDateUtc}` in `createUrl`
/// and require a multipart upload; newer responses only need the raw bytes
/// PUT to `createUrl`. The resulting token finalizes the submission.
pub async fn competition_submit(
    &self,
    file: impl AsRef<Path>,
    competition: impl AsRef<str>,
    message: impl ToString,
) -> anyhow::Result<SubmitResult> {
    let competition = competition.as_ref();
    let file = file.as_ref();
    let (content_length, last_modified) = Self::get_file_metadata(file)?;
    let file_name = file
        .file_name()
        .context("File path terminates in `..`")?
        .to_str()
        .context("File name is not valid unicode")?;
    let url_result = self
        .competitions_submissions_url(&competition, content_length, last_modified, file_name)
        .await?;
    let obj = url_result
        .as_object()
        .context("Expected json response object")?;
    // NOTE(review): only the *presence* of `isComplete` is checked, not its
    // value — this matches the official Python client's worker detection,
    // but confirm the intent.
    let upload_result = if obj.get("isComplete").is_some() {
        // Legacy worker: parse the trailing url segments
        // `{guid}/{contentLength}/{lastModifiedDateUtc}`.
        let url_list = obj
            .get("createUrl")
            .and_then(serde_json::Value::as_str)
            .context("Missing `createUrl` field")?;
        let parts: Vec<_> = url_list.split('/').rev().collect();
        if parts.len() < 3 {
            return Err(anyhow!(
                "createUrl response with incomplete segments {}",
                url_list
            ));
        }
        // Fix: with the reversed segment list the guid is the third-from-
        // last segment (`parts[2]`) and the last-modified timestamp is the
        // final one (`parts[0]`); the previous code passed them swapped.
        self.competitions_submissions_upload(
            file,
            parts[2],
            parts[1].parse()?,
            Duration::from_secs(parts[0].parse()?),
        )
        .await?
    } else {
        // New worker: PUT the file bytes directly to the provided url.
        self.upload_complete(
            file,
            obj.get("createUrl")
                .and_then(serde_json::Value::as_str)
                .context("Missing createUrl in response")?,
        )
        .await?;
        url_result
    };
    let token = upload_result
        .as_object()
        .and_then(|x| x.get("token"))
        .and_then(serde_json::Value::as_str)
        .context("Missing upload token")?;
    Ok(self
        .competitions_submissions_submit(competition, token, message)
        .await?)
}
/// PUT the raw bytes of `file` to `url`, streaming from disk.
async fn upload_complete(
    &self,
    file: impl AsRef<Path>,
    url: impl IntoUrl,
) -> anyhow::Result<reqwest::Response> {
    let source = tokio::fs::File::open(file).await?;
    let body = reqwest::Body::wrap_stream(into_bytes_stream(source));
    Self::request(self.client.put(url).body(body)).await
}
/// Multipart-upload a submission file to the legacy upload endpoint.
async fn competitions_submissions_upload(
    &self,
    file: impl AsRef<Path>,
    guid: impl AsRef<str>,
    content_length: u64,
    last_modified_date_utc: Duration,
) -> anyhow::Result<serde_json::Value> {
    let source = tokio::fs::File::open(file).await?;
    let part = multipart::Part::stream(reqwest::Body::wrap_stream(into_bytes_stream(source)));
    let form = multipart::Form::new().part("file", part);
    let url = self.join_url(format!(
        "competitions/submissions/upload/{}/{}/{}",
        guid.as_ref(),
        content_length,
        last_modified_date_utc.as_secs()
    ))?;
    Self::request_json(self.client.post(url).multipart(form)).await
}
/// Ask the API for a submission upload url for `file_name`.
async fn competitions_submissions_url(
    &self,
    id: impl AsRef<str>,
    content_length: u64,
    last_modified_date_utc: Duration,
    file_name: impl ToString,
) -> anyhow::Result<serde_json::Value> {
    let url = self.join_url(format!(
        "competitions/{}/submissions/url/{}/{}",
        id.as_ref(),
        content_length,
        last_modified_date_utc.as_secs()
    ))?;
    let form = multipart::Form::new().text("fileName", file_name.to_string());
    Self::request_json(self.client.post(url).multipart(form)).await
}
/// Create a brand-new dataset from `new_dataset`, validating its metadata
/// and uploading the files in its folder (if any) first.
pub async fn dataset_create_new(
&self,
new_dataset: DatasetNew,
) -> anyhow::Result<DatasetNewResponse> {
new_dataset.validate_resources()?;
let metadata = new_dataset.metadata;
let (owner_slug, dataset_slug) = self
.get_user_and_identifier_slug(&metadata.id)
.map(|(s1, s2)| (s1.to_string(), s2.to_string()))?;
// Reject the placeholder values shipped in the metadata template.
if dataset_slug == "INSERT_SLUG_HERE" {
return Err(KaggleError::meta(
"Default slug detected, please change values before uploading",
)
.into());
}
if metadata.title == "INSERT_SLUG_HERE" {
return Err(KaggleError::meta(
"Default title detected, please change values before uploading",
)
.into());
}
if metadata.licenses.len() != 1 {
return Err(KaggleError::meta("Please specify exactly one license").into());
}
// Same length bounds the Kaggle website enforces.
if dataset_slug.len() < 6 || dataset_slug.len() > 50 {
return Err(
KaggleError::meta("The dataset slug must be between 6 and 50 characters").into(),
);
}
if metadata.title.len() < 6 || metadata.title.len() > 50 {
return Err(
KaggleError::meta("The dataset title must be between 6 and 50 characters").into(),
);
}
let mut request = DatasetNewRequest::builder(metadata.title);
if let Some(subtitle) = &metadata.subtitle {
if subtitle.len() < 20 || subtitle.len() > 80 {
return Err(KaggleError::meta(
"Subtitle length must be between 20 and 80 characters",
)
.into());
}
request = request.subtitle(subtitle);
}
// Upload the folder's files first so their tokens can be attached below.
let files = if let Some(folder) = new_dataset.dataset_folder {
self.upload_files(folder, &metadata.resources, new_dataset.archive_mode)
.await?
} else {
vec![]
};
let mut request = request
.slug(dataset_slug)
.owner_slug(owner_slug)
.license_name(metadata.licenses[0].to_string())
.with_private(new_dataset.is_private)
.convert_to_csv(new_dataset.convert_to_csv)
.category_ids(metadata.keywords)
.files(files);
if let Some(desc) = metadata.description {
request = request.description(desc);
}
Ok(self.datasets_create_new(&request.build()).await?)
}
/// POST a prepared `DatasetNewRequest` to the create endpoint.
async fn datasets_create_new(
    &self,
    new_dataset: &DatasetNewRequest,
) -> anyhow::Result<DatasetNewResponse> {
    let url = self.join_url("datasets/create/new")?;
    self.post_json(url, Some(new_dataset)).await
}
/// Create a new version of the dataset described by the metadata file in
/// `folder`, uploading the folder's resources first.
pub async fn dataset_create_version(
&self,
folder: impl AsRef<Path>,
version_notes: impl ToString,
convert_to_csv: bool,
delete_old_versions: bool,
archive_mode: ArchiveMode,
) -> anyhow::Result<DatasetNewVersionResponse> {
let folder = folder.as_ref();
let meta_data = Self::read_dataset_metadata_file(folder).await?;
meta_data.validate_resource(folder)?;
let mut req = DatasetNewVersionRequest::new(version_notes.to_string());
if let Some(subtitle) = meta_data.subtitle {
// Same bounds the website enforces for dataset subtitles.
if subtitle.len() < 20 || subtitle.len() > 80 {
return Err(KaggleError::Metadata {
msg: "Subtitle length must be between 20 and 80 characters".to_string(),
}
.into());
}
req.set_subtitle(subtitle);
}
let files = self
.upload_files(folder, &meta_data.resources, archive_mode)
.await?;
if let Some(desc) = meta_data.description {
req.set_description(desc);
}
req.set_category_ids(meta_data.keywords);
req.set_convert_to_csv(convert_to_csv);
req.set_delete_old_versions(delete_old_versions);
req.set_files(files);
// Prefer the numeric id when the metadata carries one; otherwise fall
// back to the `{owner}/{slug}` identifier (after rejecting the template
// placeholder).
if let Some(id_no) = meta_data.id_no {
Ok(self.datasets_create_version_by_id(id_no, &req).await?)
} else {
if meta_data.id == format!("{}/INSERT_SLUG_HERE", self.credentials.username) {
return Err(KaggleError::Metadata {
msg: "Default slug detected, please change values before uploading".to_string(),
}
.into());
}
Ok(self.datasets_create_version(&meta_data.id, &req).await?)
}
}
/// Create a new version of the dataset identified by `name`
/// (`{owner}/{slug}`).
pub async fn datasets_create_version(
    &self,
    name: &str,
    dataset_req: &DatasetNewVersionRequest,
) -> anyhow::Result<DatasetNewVersionResponse> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name)?;
    let url = self.join_url(format!(
        "datasets/create/version/{}/{}",
        owner_slug, dataset_slug
    ))?;
    self.post_json(url, Some(dataset_req)).await
}
/// Create a new version of the dataset with the numeric `id`.
pub async fn datasets_create_version_by_id(
    &self,
    id: i32,
    dataset_req: &DatasetNewVersionRequest,
) -> anyhow::Result<DatasetNewVersionResponse> {
    let url = self.join_url(format!("datasets/create/version/{}", id))?;
    self.post_json(url, Some(dataset_req)).await
}
/// Download the whole dataset as `{slug}.zip` into `path` (default:
/// `datasets/{owner}/{slug}` under the download dir), optionally pinning a
/// specific version.
pub async fn dataset_download_all_files(
    &self,
    name: impl AsRef<str>,
    path: Option<PathBuf>,
    dataset_version_number: Option<&str>,
) -> anyhow::Result<PathBuf> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let mut req = self
        .client
        .get(self.join_url(format!("datasets/download/{}/{}", owner_slug, dataset_slug))?)
        .header(header::ACCEPT, HeaderValue::from_static("file"));
    if let Some(version) = dataset_version_number {
        req = req.query(&[("datasetVersionNumber", version)]);
    }
    let folder = match path {
        Some(folder) => folder,
        None => self
            .download_dir
            .join(format!("datasets/{}/{}", owner_slug, dataset_slug)),
    };
    fs::create_dir_all(&folder)?;
    Self::download_file(req, folder.join(format!("{}.zip", dataset_slug))).await
}
/// Download one file of dataset `name`, optionally pinning a version.
///
/// The final path segment of the (possibly redirected) response url is
/// used as the local file name inside `folder` (default:
/// `datasets/{owner}/{slug}` under the download dir).
pub async fn dataset_download_file(
&self,
name: impl AsRef<str>,
file_name: impl AsRef<str>,
folder: Option<PathBuf>,
dataset_version_number: Option<&str>,
) -> anyhow::Result<PathBuf> {
let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
let mut req = self
.client
.get(self.join_url(format!(
"datasets/download/{}/{}/{}",
owner_slug,
dataset_slug,
file_name.as_ref()
))?)
.header(header::ACCEPT, HeaderValue::from_static("file"));
if let Some(version) = dataset_version_number {
req = req.query(&[("datasetVersionNumber", version)]);
}
let resp = Self::request(req).await?;
// The API redirects to a storage url; its last path segment is the
// actual file name to store under.
let url = resp
.url()
.path_segments()
.context("redirected to invalid dataset download url")?
.last()
.context("no file segment in url download path")?;
let output = folder.unwrap_or_else(|| {
self.download_dir
.join(format!("datasets/{}/{}", owner_slug, dataset_slug))
});
fs::create_dir_all(&output)?;
let outfile = output.join(url);
Ok(Self::write_resp(resp, outfile).await?)
}
/// List datasets matching the query.
pub async fn datasets_list(&self, list: &DatasetsList) -> anyhow::Result<Vec<Dataset>> {
    let req = self.client.get(self.join_url("datasets/list")?).query(list);
    Self::request_json(req).await
}
/// List the files of dataset `name` (`{owner}/{slug}`).
pub async fn datasets_list_files(
    &self,
    name: impl AsRef<str>,
) -> anyhow::Result<ListFilesResult> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!("datasets/list/{}/{}", owner_slug, dataset_slug))?;
    Self::request_json(self.client.get(url)).await
}
/// Query the processing status of dataset `name`.
pub async fn datasets_status(
    &self,
    name: impl AsRef<str>,
) -> anyhow::Result<Option<serde_json::Value>> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!("datasets/status/{}/{}", owner_slug, dataset_slug))?;
    self.get_json(url).await
}
/// Request an upload slot for a dataset file; the response carries the
/// upload url and token.
pub async fn datasets_upload_file(
    &self,
    file_name: impl ToString,
    content_length: u64,
    last_modified_date_utc: Duration,
) -> anyhow::Result<FileUploadInfo> {
    let url = self.join_url(format!(
        "datasets/upload/file/{}/{}",
        content_length,
        last_modified_date_utc.as_secs()
    ))?;
    let form = multipart::Form::new().text("fileName", file_name.to_string());
    Self::request_json(self.client.post(url).multipart(form)).await
}
/// Fetch the metadata view of dataset `name`.
pub async fn datasets_view(&self, name: impl AsRef<str>) -> anyhow::Result<Dataset> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!("datasets/view/{}/{}", owner_slug, dataset_slug))?;
    self.get_json(url).await
}
/// Download all output files (and the log, when present) of the latest run
/// of kernel `name` into `path` (default:
/// `datasets/{owner}/{slug}/output` under the download dir).
pub async fn kernels_output(
&self,
name: impl AsRef<str>,
path: Option<PathBuf>,
) -> anyhow::Result<Vec<PathBuf>> {
let name = name.as_ref();
let (owner_slug, kernel_slug) = self.get_user_and_identifier_slug(name)?;
let folder = path.unwrap_or_else(|| {
self.download_dir
.join(format!("datasets/{}/{}/output", owner_slug, kernel_slug,))
});
fs::create_dir_all(&folder)?;
let resp = self.kernel_output(name).await?;
let mut outfiles = Vec::with_capacity(resp.files.len());
// Write up to 3 output files concurrently; completion order is not
// significant since each future returns its own path.
let mut outstream = stream::iter(resp.files.into_iter().map(|file| async {
let outfile = folder.join(file.file_name);
let content = file.url.content;
tokio::fs::write(&outfile, content).await?;
Ok::<_, std::io::Error>(outfile)
}))
.buffer_unordered(3);
while let Some(f) = outstream.next().await {
outfiles.push(f?);
}
// The run log, when present, is stored next to the output files.
if let Some(log) = resp.log {
let outfile = folder.join(format!("{}.log", kernel_slug));
tokio::fs::write(&outfile, log).await?;
outfiles.push(outfile);
}
Ok(outfiles)
}
/// Fetch the output (files and log) of the latest run of kernel `name`.
pub async fn kernel_output(&self, name: impl AsRef<str>) -> anyhow::Result<KernelOutput> {
    let (owner_slug, kernel_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    if kernel_slug.len() < 5 {
        return Err(KaggleError::meta(format!(
            "Kernel slug `{}` must be at least five characters.",
            kernel_slug
        ))
        .into());
    }
    let url = self.join_url(format!(
        "kernels/output?userName={}&kernelSlug={}",
        owner_slug, kernel_slug
    ))?;
    self.get_json(url).await
}
/// Fetch the kernel blob and metadata for `name` (`{owner}/{slug}`).
pub async fn kernel_pull(&self, name: impl AsRef<str>) -> anyhow::Result<KernelPullResponse> {
    let (owner_slug, kernel_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!(
        "kernels/pull?userName={}&kernelSlug={}",
        owner_slug, kernel_slug
    ))?;
    self.get_json(url).await
}
/// Pull kernel code (and optionally its metadata file) into a local
/// folder. Returns the code-file path and, when `pull.with_metadata` is
/// set, the metadata-file path.
pub async fn kernels_pull(
&self,
pull: KernelPullRequest,
) -> anyhow::Result<(PathBuf, Option<PathBuf>)> {
let (owner_slug, kernel_slug) = self.get_user_and_identifier_slug(&pull.name)?;
let resp = self.kernel_pull(&pull.name).await?;
let folder = pull.output.unwrap_or_else(|| {
self.download_dir
.join(format!("kernels/{}/{}", owner_slug, kernel_slug))
});
fs::create_dir_all(&folder)?;
let metadata_path = folder.join(Self::KERNEL_METADATA_FILE);
// Prefer the code-file name recorded in an existing local metadata file
// (unless it is still the template placeholder); fall back to the name
// reported by the API, then to `script.py`.
let file_name = if metadata_path.exists() {
let existing_meta = Self::read_kernel_metadata_file(&metadata_path).await?;
if Some("INSERT_CODE_FILE_PATH_HERE") == existing_meta.code_file.as_deref() {
None
} else {
existing_meta.code_file
}
} else {
resp.code_file_name()
}
.unwrap_or_else(|| "script.py".to_string());
let output = folder.join(file_name);
tokio::fs::write(&output, resp.blob.source).await?;
if pull.with_metadata {
tokio::fs::write(
&metadata_path,
serde_json::to_string_pretty(&resp.metadata)?,
)
.await?;
Ok((output, Some(metadata_path)))
} else {
Ok((output, None))
}
}
/// Push (create/update) a kernel from a local `folder` containing
/// `kernel-metadata.json` and the referenced source file.
///
/// Validates the metadata (title length, slug/title agreement, data
/// sources), strips notebook code-cell outputs before upload, and maps
/// RMarkdown notebooks to the `R` language.
pub async fn kernels_push(
    &self,
    folder: impl AsRef<Path>,
) -> anyhow::Result<KernelPushResponse> {
    let folder = folder.as_ref();
    let metadata = Self::read_kernel_metadata_file(folder).await?;
    if metadata.title.len() < 5 {
        return Err(KaggleError::meta("Title must be at least five characters").into());
    }
    metadata.is_dataset_sources_valid()?;
    metadata.is_kernel_sources_valid()?;
    let code_path = metadata
        .code_file
        .ok_or_else(|| KaggleError::meta("A source file must be specified in the metadata"))?;
    let code_file = folder.join(code_path);
    // Fix: the previous `!is_file() && !exists()` reduced to `!exists()`,
    // so a directory path slipped past this guard and failed later with an
    // opaque read error; require an actual file here.
    if !code_file.is_file() {
        return Err(KaggleError::meta(format!(
            "Source file not found:{}",
            code_file.display()
        ))
        .into());
    }
    let (_owner_slug, kernel_slug) = self
        .get_user_and_identifier_slug(&metadata.id)
        .map(|(s1, s2)| (s1.to_string(), s2.to_string()))?;
    // The slug in the id must match the slugified title.
    if kernel_slug.to_lowercase() != slug::slugify(&metadata.title) {
        return Err(
            KaggleError::meta("kernel title does not resolve to the specified id").into(),
        );
    }
    let script_body = tokio::fs::read(&code_file).await?;
    let text = if Some(PushKernelType::Notebook) == metadata.kernel_type {
        // Notebooks are uploaded as JSON with all code-cell outputs cleared.
        let mut json_body = serde_json::from_slice::<serde_json::Value>(&script_body)?;
        let obj = json_body
            .as_object_mut()
            .context("Expected json object in code file")?;
        if let Some(cells) = obj.get_mut("cells").and_then(|x| x.as_array_mut()) {
            for cell in cells {
                if let Some(cell_obj) = cell.as_object_mut() {
                    if cell_obj.contains_key("outputs")
                        && Some("code") == cell_obj.get("cell_type").and_then(|x| x.as_str())
                    {
                        cell_obj
                            .insert("outputs".to_string(), serde_json::Value::Array(vec![]));
                    }
                }
            }
        }
        serde_json::to_string(&json_body)?
    } else {
        String::from_utf8_lossy(&script_body).to_string()
    };
    // The push API has no RMarkdown-notebook language; send it as `R`.
    let language = if Some(PushKernelType::Notebook) == metadata.kernel_type
        && Some(PushLanguageType::Rmarkdown) == metadata.language
    {
        Some(PushLanguageType::R)
    } else {
        metadata.language
    };
    let mut req = KernelPushRequest::new(text)
        .with_new_title(metadata.title)
        .with_slug(metadata.id)
        .with_dataset_data_sources(metadata.dataset_sources)
        .with_competition_data_sources(metadata.competition_sources)
        .with_kernel_data_sources(metadata.kernel_sources)
        .with_category_ids(metadata.keywords);
    if let Some(id_no) = metadata.id_no {
        req.set_id(id_no);
    }
    if let Some(language) = language {
        req.set_language(language);
    }
    if let Some(kernel) = metadata.kernel_type {
        req.set_kernel_type(kernel);
    }
    if let Some(enable_gpu) = metadata.enable_gpu {
        req.set_enable_gpu(enable_gpu);
    }
    if let Some(enable_internet) = metadata.enable_internet {
        req.set_enable_internet(enable_internet);
    }
    if let Some(is_private) = metadata.is_private {
        req.set_is_private(is_private);
    }
    Ok(self.kernel_push(&req).await?)
}
/// POST a prepared `KernelPushRequest`.
pub async fn kernel_push(
    &self,
    kernel_push_request: &KernelPushRequest,
) -> anyhow::Result<KernelPushResponse> {
    let url = self.join_url("kernels/push")?;
    self.post_json(url, Some(kernel_push_request)).await
}
/// Query the run status of kernel `name`.
pub async fn kernel_status(&self, name: impl AsRef<str>) -> anyhow::Result<serde_json::Value> {
    let (owner_slug, kernel_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!(
        "kernels/status?userName={}&kernelSlug={}",
        owner_slug, kernel_slug
    ))?;
    Self::request_json(self.client.get(url)).await
}
/// List kernels matching the query.
pub async fn kernels_list(&self, kernel_list: &KernelsList) -> anyhow::Result<Vec<Kernel>> {
    let req = self
        .client
        .get(self.join_url("kernels/list")?)
        .query(kernel_list);
    Self::request_json(req).await
}
/// Fetch the editable metadata of dataset `name`.
pub async fn metadata_get(&self, name: impl AsRef<str>) -> anyhow::Result<DatasetMetadata> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!("datasets/metadata/{}/{}", owner_slug, dataset_slug))?;
    Self::request_json(self.client.get(url)).await
}
/// Update a dataset's settings from a local metadata file.
///
/// When `path` is `None`, the metadata file is looked up in the default
/// download location `datasets/{owner}/{slug}`.
pub async fn dataset_metadata_update(
    &self,
    name: impl AsRef<str>,
    path: Option<PathBuf>,
) -> anyhow::Result<serde_json::Value> {
    let name = name.as_ref();
    let metadata = match path {
        Some(path) => Self::read_dataset_metadata_file(path).await?,
        None => {
            let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name)?;
            let default_dir = self
                .download_dir
                .join(format!("datasets/{}/{}", owner_slug, dataset_slug));
            Self::read_dataset_metadata_file(default_dir).await?
        }
    };
    let settings = metadata.into();
    self.metadata_post(name, &settings).await
}
/// POST updated dataset settings for `name`.
pub async fn metadata_post(
    &self,
    name: impl AsRef<str>,
    settings: &DatasetUpdateSettingsRequest,
) -> anyhow::Result<serde_json::Value> {
    let (owner_slug, dataset_slug) = self.get_user_and_identifier_slug(name.as_ref())?;
    let url = self.join_url(format!("datasets/metadata/{}/{}", owner_slug, dataset_slug))?;
    self.post_json(url, Some(settings)).await
}
}
/// Adapt an `AsyncRead` source into a stream of `Bytes` chunks suitable
/// for `reqwest::Body::wrap_stream`.
fn into_bytes_stream<R>(r: R) -> impl Stream<Item = tokio::io::Result<Bytes>>
where
    R: AsyncRead,
{
    let framed = codec::FramedRead::new(r, codec::BytesCodec::new());
    framed.map_ok(|chunk| chunk.freeze())
}
#[cfg(test)]
mod tests {
use super::*;
/// Client with dummy credentials; the tests below never hit the network.
fn kaggle() -> KaggleApiClient {
KaggleApiClient::builder()
.auth(Authentication::with_credentials("name", "key"))
.build()
.unwrap()
}
/// The default competition listing should serialize its query string with
/// every filter present (empty values) and `page=1`.
#[test]
fn competition_query() {
let kaggle = kaggle();
let req = kaggle
.client
.get(kaggle.join_url("competitions/list").unwrap())
.query(&CompetitionsList::default())
.build()
.unwrap();
assert_eq!(
*req.url(),
format!(
"{}?group=&category=&sortBy=&page=1&search=",
kaggle.join_url("competitions/list").unwrap()
)
.parse()
.unwrap()
)
}
}