import functools
from typing import Callable, Optional
import logging
import threading
from briefcase.integrations.lakefs.context import versioned_context
# Module logger; the decorator uses it to warn when a call runs without a
# Briefcase client.
logger = logging.getLogger(__name__)
# Per-thread storage for the ambient Briefcase client, managed by
# set_briefcase_client / clear_briefcase_client. A value set on one thread is
# not visible on other threads.
# NOTE(review): threading.local is shared by all coroutines on one thread; if
# this is used under asyncio, contextvars may be the intended mechanism -- confirm.
_thread_local = threading.local()
def versioned(
    repository: str,
    branch: str = "main",
    commit: str = "latest",
    client_param: str = "versioned_client"
):
    """Create a decorator that runs a function inside a lakeFS versioned context.

    On every call of the decorated function, a Briefcase client is resolved
    in this order:

    1. An explicit ``briefcase_client`` keyword argument. It is removed from
       the keyword arguments, so the function never receives it.
    2. Otherwise, the client registered for the current thread with
       ``set_briefcase_client``.

    What happens next depends on whether a client was found:

    * If none is found, a warning is logged and the function is called with
      its remaining arguments. ``client_param`` is NOT injected in this case.
    * If one is found, ``versioned_context(client, repository, branch,
      commit)`` is entered, and the function runs inside that ``with``
      block. The object the context yields is passed to the function as
      the keyword argument named ``client_param``, replacing any value the
      caller passed under that name.

    Args:
        repository: Repository name forwarded to ``versioned_context``.
        branch: Branch forwarded to ``versioned_context``. Defaults to
            ``"main"``.
        commit: Commit reference forwarded to ``versioned_context``.
            Defaults to ``"latest"``. How ``"latest"`` is interpreted is up
            to ``versioned_context`` -- TODO confirm.
        client_param: Name of the keyword argument that receives the
            versioned client.

    Returns:
        A decorator. The wrapped function returns whatever the original
        function returns.

    NOTE(review): the wrapper is synchronous. If it is applied to an
    ``async def`` function, calling that function only creates a coroutine
    object. The context then exits before the coroutine body runs. Confirm
    that async functions are never decorated.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # May pop "briefcase_client" out of kwargs.
            client = _get_briefcase_client(kwargs)

            if client is None:
                logger.warning(
                    "versioned decorator requires a Briefcase client. "
                    "Either pass 'briefcase_client' as kwarg or use within "
                    "a Briefcase instrumented context. Continuing without versioning."
                )
                return func(*args, **kwargs)

            # func runs while the context is still active, so it uses the
            # injected client inside that scope.
            with versioned_context(client, repository, branch, commit) as scoped_client:
                kwargs[client_param] = scoped_client
                return func(*args, **kwargs)

        return wrapper

    return decorator
# Alias: `lakefs_versioned` is the very same decorator factory as `versioned`.
lakefs_versioned = versioned
def _get_briefcase_client(kwargs: dict):
    """Resolve the Briefcase client for a decorated call.

    An explicit ``briefcase_client`` keyword argument takes precedence. It
    is popped from ``kwargs`` so it is not forwarded to the wrapped function.
    Otherwise, the function falls back to the client registered for the
    current thread via ``set_briefcase_client``.

    Args:
        kwargs: Keyword arguments of the decorated call. This dict is
            modified in place when it contains ``briefcase_client``.

    Returns:
        The resolved client, or ``None`` if no client is available.

        An explicitly passed ``briefcase_client=None`` is returned as
        ``None``. In that case the thread-local fallback is not consulted.
    """
    if "briefcase_client" in kwargs:
        # Pop rather than get: the wrapped function must not receive this kwarg.
        return kwargs.pop("briefcase_client")
    # getattr with a default already covers the "nothing registered on this
    # thread" case. No try/except is needed, and a bare except would also
    # swallow KeyboardInterrupt/SystemExit.
    return getattr(_thread_local, "briefcase_client", None)
def set_briefcase_client(client):
    """Register *client* as the ambient Briefcase client for the current thread.

    Functions decorated with ``versioned`` that run on this thread without
    an explicit ``briefcase_client`` keyword argument will use this client.
    Registering a client here does not affect other threads.
    """
    setattr(_thread_local, "briefcase_client", client)
def clear_briefcase_client():
    """Forget the ambient Briefcase client registered for the current thread.

    It is safe to call this when no client is registered; the call is then
    a no-op.
    """
    try:
        del _thread_local.briefcase_client
    except AttributeError:
        # No client was ever set (or it was already cleared) on this thread.
        pass