grafbase_local_backend/
project.rs1use crate::consts::{DEFAULT_DOT_ENV, DEFAULT_SCHEMA, USER_AGENT};
2use crate::errors::BackendError;
3use async_compression::tokio::bufread::GzipDecoder;
4use async_tar::Archive;
5use common::consts::{GRAFBASE_DIRECTORY_NAME, GRAFBASE_ENV_FILE_NAME, GRAFBASE_SCHEMA_FILE_NAME};
6use common::environment::Environment;
7use http_cache_reqwest::{CACacheManager, Cache, CacheMode, HttpCache};
8use reqwest::{header, Client};
9use reqwest_middleware::ClientBuilder;
10use serde::Deserialize;
11use std::env;
12use std::fs;
13use std::io::{Error as IoError, ErrorKind as IoErrorKind};
14use std::iter::Iterator;
15use std::path::PathBuf;
16use tokio_stream::StreamExt;
17use tokio_util::compat::TokioAsyncReadCompatExt;
18use tokio_util::io::StreamReader;
19use url::Url;
20
21#[tokio::main]
63pub async fn init(name: Option<&str>, template: Option<&str>) -> Result<(), BackendError> {
64 let project_path = to_project_path(name)?;
65 let grafbase_path = project_path.join(GRAFBASE_DIRECTORY_NAME);
66 let schema_path = grafbase_path.join(GRAFBASE_SCHEMA_FILE_NAME);
67
68 if grafbase_path.exists() {
69 Err(BackendError::AlreadyAProject(grafbase_path))
70 } else if let Some(template) = template {
71 if template.contains('/') {
75 if let Ok(repo_url) = Url::parse(template) {
76 match repo_url.host_str() {
77 Some("github.com") => handle_github_repo_url(grafbase_path, &repo_url).await,
78 _ => Err(BackendError::UnsupportedTemplateURL(template.to_string())),
79 }
80 } else {
81 return Err(BackendError::MalformedTemplateURL(template.to_owned()));
82 }
83 } else {
84 download_github_template(
85 grafbase_path,
86 GitHubTemplate::Grafbase(GrafbaseGithubTemplate { path: template }),
87 )
88 .await
89 }
90 } else {
91 tokio::fs::create_dir_all(&grafbase_path)
92 .await
93 .map_err(BackendError::CreateGrafbaseDirectory)?;
94
95 let dot_env_path = grafbase_path.join(GRAFBASE_ENV_FILE_NAME);
96 let schema_write_result = fs::write(schema_path, DEFAULT_SCHEMA).map_err(BackendError::WriteSchema);
97 let dot_env_write_result = fs::write(dot_env_path, DEFAULT_DOT_ENV).map_err(BackendError::WriteSchema);
98
99 if schema_write_result.is_err() || dot_env_write_result.is_err() {
100 tokio::fs::remove_dir_all(&grafbase_path)
101 .await
102 .map_err(BackendError::DeleteGrafbaseDirectory)?;
103 }
104
105 schema_write_result?;
106 dot_env_write_result?;
107
108 Ok(())
109 }
110}
111
112async fn handle_github_repo_url(grafbase_path: PathBuf, repo_url: &Url) -> Result<(), BackendError> {
113 if let Some(mut segments) = repo_url.path_segments().map(Iterator::collect::<Vec<_>>) {
114 if segments.last() == Some(&"") {
116 segments.pop();
117 }
118
119 if segments.contains(&"") {
121 return Err(BackendError::UnsupportedTemplateURL(repo_url.to_string()));
122 }
123
124 match segments.len() {
125 2 => {
126 let org = &segments[0];
127
128 let repo = &segments[1];
129
130 let branch = get_default_branch(org, repo).await?;
131
132 download_github_template(
133 grafbase_path,
134 GitHubTemplate::External(ExternalGitHubTemplate {
135 org,
136 repo,
137 branch: &branch,
138 path: None,
139 }),
140 )
141 .await
142 }
143 4.. if segments[2] == "tree" => {
144 let org = &segments[0];
145
146 let repo = &segments[1];
147
148 let branch = &segments[3];
149
150 let path = segments.get(4..).map(|path| path.join("/"));
151
152 download_github_template(
153 grafbase_path,
154 GitHubTemplate::External(ExternalGitHubTemplate {
155 org,
156 repo,
157 path,
158 branch,
159 }),
160 )
161 .await
162 }
163 _ => Err(BackendError::UnsupportedTemplateURL(repo_url.to_string())),
164 }
165 } else {
166 Err(BackendError::UnsupportedTemplateURL(repo_url.to_string()))
167 }
168}
169
170#[derive(Deserialize)]
171struct RepoInfo {
172 default_branch: String,
173}
174
175async fn get_default_branch(org: &str, repo: &str) -> Result<String, BackendError> {
176 let client = Client::new();
177
178 let response = client
179 .get(format!("https://api.github.com/repos/{org}/{repo}"))
180 .header(header::USER_AGENT, USER_AGENT)
182 .send()
183 .await
184 .map_err(|_| BackendError::StartGetRepositoryInformation(format!("{org}/{repo}")))?;
185
186 if !response.status().is_success() {
187 return Err(BackendError::GetRepositoryInformation(format!("{org}/{repo}")));
188 }
189
190 let repo_info = response
191 .json::<RepoInfo>()
192 .await
193 .map_err(|_| BackendError::ReadRepositoryInformation(format!("{org}/{repo}")))?;
194
195 Ok(repo_info.default_branch)
196}
197
198fn to_project_path(name: Option<&str>) -> Result<PathBuf, BackendError> {
199 let current_dir = env::current_dir().map_err(|_| BackendError::ReadCurrentDirectory)?;
200 match name {
201 Some(name) => {
202 let project_path = current_dir.join(name);
203 if project_path.exists() {
204 Err(BackendError::ProjectDirectoryExists(project_path))
205 } else {
206 Ok(project_path)
207 }
208 }
209 None => Ok(current_dir),
210 }
211}
212
/// A fully specified template location in a GitHub repository.
///
/// The template's `grafbase` directory is expected at `[path/]grafbase` relative to the
/// repository root on `branch`.
#[derive(Clone)]
struct ExternalGitHubTemplate<'t> {
    /// Owner of the repository (the `{org}` part of `{org}/{repo}`).
    org: &'t str,
    /// Repository name.
    repo: &'t str,
    /// Branch whose archive is downloaded.
    branch: &'t str,
    /// Optional sub-directory of the repository holding the template; `None` means the root.
    path: Option<String>,
}
220
/// A template from the official `grafbase/grafbase` repository, identified by its directory
/// below `templates/`.
struct GrafbaseGithubTemplate<'t> {
    /// Template directory name relative to `templates/`.
    path: &'t str,
}
224
225enum GitHubTemplate<'a> {
226 Grafbase(GrafbaseGithubTemplate<'a>),
227 External(ExternalGitHubTemplate<'a>),
228}
229
230impl<'a> GitHubTemplate<'a> {
231 pub fn into_external_github_template(self) -> ExternalGitHubTemplate<'a> {
232 match self {
233 Self::Grafbase(GrafbaseGithubTemplate { path }) => ExternalGitHubTemplate {
234 org: "grafbase",
235 repo: "grafbase",
236 path: Some(format!("templates/{path}")),
237 branch: "main",
238 },
239 Self::External(template @ ExternalGitHubTemplate { .. }) => template,
240 }
241 }
242}
243
244async fn download_github_template(grafbase_path: PathBuf, template: GitHubTemplate<'_>) -> Result<(), BackendError> {
245 let ExternalGitHubTemplate {
246 org,
247 repo,
248 path,
249 branch,
250 } = template.into_external_github_template();
251
252 let org_and_repo = format!("{org}/{repo}");
253
254 let extraction_dir = PathBuf::from(format!("{repo}-{branch}"));
255
256 let mut template_path: PathBuf = PathBuf::from(&extraction_dir);
257
258 if let Some(path) = path {
259 template_path.push(path);
260 }
261
262 template_path.push("grafbase");
263
264 let extraction_result = stream_github_archive(grafbase_path, &org_and_repo, template_path, branch).await;
265
266 if extraction_dir.exists() {
267 tokio::fs::remove_dir_all(extraction_dir)
268 .await
269 .map_err(BackendError::CleanExtractedFiles)?;
270 }
271
272 extraction_result
273}
274
275async fn stream_github_archive<'a>(
276 grafbase_path: PathBuf,
277 org_and_repo: &'a str,
278 template_path: PathBuf,
279 branch: &'a str,
280) -> Result<(), BackendError> {
281 let cache_directory = dirs::cache_dir().and_then(|path| path.join("grafbase").to_str().map(ToOwned::to_owned));
284
285 let mut client_builder = ClientBuilder::new(Client::new());
286
287 if let Some(cache_directory) = cache_directory {
288 client_builder = client_builder.with(Cache(HttpCache {
289 mode: CacheMode::Default,
290 manager: CACacheManager { path: cache_directory },
291 options: None,
292 }));
293 }
294
295 let client = client_builder.build();
296
297 let tar_gz_response = client
298 .get(format!("https://codeload.github.com/{org_and_repo}/tar.gz/{branch}"))
299 .send()
300 .await
301 .map_err(|error| BackendError::StartDownloadRepoArchive(org_and_repo.to_owned(), error))?;
302
303 if !tar_gz_response.status().is_success() {
304 return Err(BackendError::DownloadRepoArchive(org_and_repo.to_owned()));
305 }
306
307 let tar_gz_stream = tar_gz_response
308 .bytes_stream()
309 .map(|result| result.map_err(|error| IoError::new(IoErrorKind::Other, error)));
310
311 let tar_gz_reader = StreamReader::new(tar_gz_stream);
312 let tar = GzipDecoder::new(tar_gz_reader);
313 let archive = Archive::new(tar.compat());
314
315 let mut entries = archive.entries().map_err(|_| BackendError::ReadArchiveEntries)?;
316
317 while let Some(entry) = entries.next().await {
318 let mut entry = entry.map_err(BackendError::ExtractArchiveEntry)?;
319
320 if entry
321 .path()
322 .ok()
323 .filter(|path| path.starts_with(&template_path))
324 .is_some()
325 {
326 entry.unpack_in(".").await.map_err(BackendError::ExtractArchiveEntry)?;
327 }
328 }
329
330 if !template_path.exists() {
331 return Err(BackendError::TemplateNotFound);
332 }
333
334 let project_folder = grafbase_path.parent().expect("must exist");
335
336 let named_project = !project_folder.exists();
337
338 if named_project {
339 tokio::fs::create_dir(project_folder)
340 .await
341 .map_err(BackendError::CreateProjectDirectory)?;
342 }
343
344 let rename_result = tokio::fs::rename(template_path, &grafbase_path)
345 .await
346 .map_err(BackendError::MoveExtractedFiles);
347
348 if rename_result.is_err() {
349 tokio::fs::remove_dir_all(project_folder)
350 .await
351 .map_err(BackendError::CleanExtractedFiles)?;
352 }
353
354 rename_result
355}
356
357pub fn reset() -> Result<(), BackendError> {
365 let environment = Environment::get();
366
367 fs::remove_dir_all(&environment.database_directory_path).map_err(BackendError::DeleteDatabaseDirectory)?;
368
369 Ok(())
370}